Compare commits

..

22 Commits

Author SHA1 Message Date
Daniel Hiltgen
1813ff85a0 cuda: bring back CC 5.2 (#12666)
Forward compat on the newer driver doesn't seem to be working.
This should get 5.2 working on newer drivers again.
2025-10-16 13:07:41 -07:00
Daniel Hiltgen
b531777a66 test: add a few missing embedding models (#12661) 2025-10-16 09:36:25 -07:00
Daniel Hiltgen
fe3ec8dbf0 Revert "Workaround broken NVIDIA iGPU free VRAM data (#12490)" (#12642)
The workaround has been moved into the underlying C++ code.

This reverts commit e4340667e3.
2025-10-16 09:09:48 -07:00
Thomas Stocker
c744134287 vulkan: Get FilterID from Backend for Vulkan (#12655)
* vulkan: Get FilterID from Backend for Vulkan

* Fixing patch
2025-10-16 09:07:35 -07:00
weedge
4be41d2d45 readme: add achatbot-go to community integrations (#12629) 2025-10-15 21:54:15 -07:00
zhetaicheleba
de670570c9 fs/ggml: fix function name in comment (#12630) 2025-10-15 21:53:38 -07:00
Devon Rifkin
201d93716e Merge pull request #12651 from ollama/drifkin/oai-conversion
openai: make tool call conversion fns public
2025-10-15 21:10:30 -07:00
Devon Rifkin
160cecc8e2 openai: make tool call conversion fns public 2025-10-15 20:54:58 -07:00
Daniel Hiltgen
8b6e5baee7 CI: Set up temporary opt-out Vulkan support (#12614)
Initially Vulkan support in Ollama will require building from source.  Once it is
more thoroughly tested and we have fixed any critical bugs, then we can
bundle Vulkan into the official binary releases.
2025-10-15 14:18:01 -07:00
Daniel Hiltgen
75d17fc6c2 perf: backport cuda iGPU sched spin (#12641) 2025-10-15 11:52:14 -07:00
Santosh Bhavani
8fafc8af77 ml/backend/ggml: NVML fallback for unified memory GPUs (#12619)
* Simplify NVML fallback for unified memory GPUs

Remove device-specific checks and environment variable dependency for
NVML_ERROR_NOT_SUPPORTED fallback. When NVML doesn't support memory
queries, unconditionally use /proc/meminfo instead of checking device
names or OLLAMA_UNIFIED_MEMORY environment variable.

This provides better memory reporting by using MemAvailable which
accounts for reclaimable memory, avoiding the underreporting issue
described in NVIDIA support article a_id/5728.

Tested on NVIDIA GB10 unified memory iGPU with consistent and accurate
memory reporting across multiple model load/unload cycles.

* Add NVML fallback patch for unified memory GPUs
2025-10-15 11:40:06 -07:00
Jesse Gross
c3c85aa06c llm: Enable flash attention by default for gemma3 2025-10-15 10:42:12 -07:00
Jeffrey Morgan
0d713051a2 envconfig: default to port 443 when connecting to ollama.com (#12617) 2025-10-14 23:38:24 -07:00
Parth Sareen
c4c5a4a01e types: send index for tool calls (#12625) 2025-10-14 19:35:15 -07:00
Jesse Gross
3dcfd5f69e llm: Perform eviction when num_gpu is set with new estimates
Currently, if you set num_gpu then this forces the model to
load with that number of layers in the current configuration.
This is done regardless of any other information, which means
that no eviction is performed even if another model is loaded.

This behavior is different from the old estimates (and still
happens for models that runs on the llama engine). In those
cases, models would be evicted if needed to load at the requested
number of layers. That behavior is more useful and less surprising,
so this changes the new estimates to match.

Fixes #12580
2025-10-14 17:46:36 -07:00
Devon Rifkin
53a969d509 Merge pull request #12621 from ollama/drifkin/any-of
qwen3-coder: support anyOf when parsing tool calls
2025-10-14 15:51:24 -07:00
Devon Rifkin
08fbb60bb2 qwen3-coder: support anyOf when parsing tool calls 2025-10-14 15:33:05 -07:00
Daniel Hiltgen
850da848c5 logs: fix bogus "0 MiB free" log line (#12590)
On the llama runner, after the recent GGML bump a new log line reports
incorrect 0 MiB free after our patch to remove memory from the props.  This
adjusts the llama.cpp code to fetch the actual free memory of the active device.
2025-10-14 11:26:28 -07:00
Thomas Stocker
2aba569a2a Vulkan based on #9650 (#11835)
* implement the vulkan C backend

* add support in gpu.go

* add support in gen_linux.sh

* it builds

* fix segfault

* fix compilation

* fix free memory monitor

* fix total memory monitor

* update gpu.go

* fix build

* fix check_perfmon len

* remove cap_get_bound check

* fix vulkan handle releasing

* fix build on federa 40

* fix vulkan on windows

* making amdgpu work on arm achitecutre with vulkan

* add x86_64 lines in VulkanGlobs and capLinuxGlobs

* add aarch64 lines in vulkanGlobs and capLinuxGlobs

* Fix variable name

* Add vulkan build patch from @jmorganca

* Sync vendored ggml to add Vulkan support

* Updated dockerfile

https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Installing rocm library

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* This version works well

built based on this: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Applied 00-fix-vulkan-building.patch

Work done by McBane87 here: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Fixed the "detached head" issues

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Merged in the right direction

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Merging the latest stable (#2)

* Applied 00-fix-vulkan-building.patch

* Implemented vulkan backend based on the work done by whyvl, Dts0, McBane87 and others

Tested on AMD Ryzen 7 8845HS w/ Radeon 780M Graphics with ROCm disabled

```
[GIN-debug] POST   /v1/chat/completions      --> github.com/ollama/ollama/server.(*Server).ChatHandler-fm (6 handlers)
[GIN-debug] POST   /v1/completions           --> github.com/ollama/ollama/server.(*Server).GenerateHandler-fm (6 handlers)
[GIN-debug] POST   /v1/embeddings            --> github.com/ollama/ollama/server.(*Server).EmbedHandler-fm (6 handlers)
[GIN-debug] GET    /v1/models                --> github.com/ollama/ollama/server.(*Server).ListHandler-fm (6 handlers)
[GIN-debug] GET    /v1/models/:model         --> github.com/ollama/ollama/server.(*Server).ShowHandler-fm (6 handlers)
time=2025-03-11T13:00:40.793Z level=INFO source=gpu.go:199 msg="vulkan: load libvulkan and libcap ok"
time=2025-03-11T13:00:40.877Z level=INFO source=gpu.go:421 msg="error looking up vulkan GPU memory" error="device is a CPU"
time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:443 msg="amdgpu detected, but no compatible rocm library found.  Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install"
time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:348 msg="unable to verify rocm library: no suitable rocm found, falling back to CPU"
time=2025-03-11T13:00:40.879Z level=INFO source=types.go:137 msg="inference compute" id=0 library=vulkan variant="" compute=1.3 driver=1.3 name="AMD Radeon Graphics (RADV GFX1103_R1)" total="15.6 GiB" available="15.6 GiB"
```

```
 # ollama run phi4:14b
>>> /set verbose
Set 'verbose' mode.
>>> how's it going?
Hello! I'm here to help you with any questions or tasks you have. How can I assist you today? 😊

total duration:       3.341959745s
load duration:        18.165612ms
prompt eval count:    15 token(s)
prompt eval duration: 475ms
prompt eval rate:     31.58 tokens/s
eval count:           26 token(s)
eval duration:        2.846s
eval rate:            9.14 tokens/s
>>>
```

* This is no longer needed

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Fixes SIGSEGV: segmentation violation running gemma3 models on ollama 0.6.0 #21

Patch provided by McBane87 on https://github.com/whyvl/ollama-vulkan/issues/21

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Applied 04-disable-mmap-vulkan.patch

From: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Pulled new upstream code for ggml-bulkan backend

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Merged latest ollama 0.6.2 and nasrally's Flash Attention patches (#5)

* readme: add Ellama to list of community integrations (#9800)

* readme: add screenpipe to community integrations (#9786)

* Add support for ROCm gfx1151 (#9773)

* conditionally enable parallel pipelines

* sample: make mutations in transforms explicit (#9743)

* updated minP to use early exit making use of sorted tokens

* ml/backend/ggml: allocate memory with malloc when loading model (#9822)

* runner: remove cache prompt flag from ollama runner (#9826)

We do not need to bypass the prompt caching in the ollama runner yet, as
only embedding models needed to bypass the prompt caching. When embedding
models are implemented they can skip initializing this cache completely.

* ollamarunner: Check for minBatch of context space when shifting

Models can specify that a group of inputs need to be handled a single
batch. However, context shifting didn't respect this and could trigger
a break anyways. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set less than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.

* Applied latest patches from McBane87

See this for details: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2708820861

Signed-off-by: Vadim Grinco <vadim@grinco.eu>

* Add ability to enable flash attention on vulkan (#4)

* discover: add flash attention handling for vulkan
* envconfig: fix typo in config.go

As part of the process some code was refactored and I added a new field
FlashAttention to GpuInfo since the previous solution didn't allow for a
granular check via vulkan extensions. As a side effect, this now allows
for granular per-device FA support checking in other places

---------

Signed-off-by: Vadim Grinco <vadim@grinco.eu>
Co-authored-by: zeo <108888572+zeozeozeo@users.noreply.github.com>
Co-authored-by: Louis Beaumont <louis.beaumont@gmail.com>
Co-authored-by: Daniel Hiltgen <dhiltgen@users.noreply.github.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Nikita <50599445+nasrally@users.noreply.github.com>

* Revert Readme changes

* Revert

* Revert changes in amd_linux.go

* Revert changes in amd_linux.go

* Remove flashattention setting gpu.go

* Revert whitespace changes in gpu.go

* Revert changes in transforms_test.go

* Revert changes in runner.go

* Revert changes in Makefile.sync

* Revert some unintented changes in Dockerfile

* Revert vulkan copy changes in Dockerfile

* Update Vulkan Code to de4c07f93783a1a96456a44dc16b9db538ee1618

* Fixed duplicate sync in ggml.go

* Revert changes in ggml.go

* Revert chnages in ggml.go

* enable falsh attention on vulkan

* revert remove parenthesis

* fixed flash attention logic enabling

* vk_check_flash_attention 0 means supported

* Update gpu.go

* Add vulkan to Windows Build script

* Remove commented out code

* Enable Vulkan Flash attention in FlashAttentionSupported

* Fix logging

* Update Vulkan backend to e54d41befcc1575f4c898c5ff4ef43970cead75f

* Removed libcap related code

libcap is not directly related to Vulkan and should be added by its own PR. It adds additional library dependencies for building and also requires users to run setcap or run ollama as root, which is not ideal for easy use

* Fix Unit Test (Add Vulkan Library)

* Add vulkan to TestHomogeneousGPUs
Test

* vulkan: get GPU ID (ollama v0.11.5)

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* disable mmap for vulkan

* Reduce Changes remove TestHomogeneousGPUs (doesn't exist on master)

* Update vulkan version to the version used in llama.cpp

* rename gpu patch to correct number

* added Vulkan API to get correct Device UUID

current UUID from pipelineCacheUUID does not match CUDA

* Fix GPU ID Patch

* Remove Code not in llama.cpp

* modified UUID code inside ggml

* Fix Patch

* Copied minimal definition from vulkan header

* Fix compile error in Mac

Metal is preferred so we're disabling Vulkan for now

* Removed unused code

Fix linter error in CI

* Fix patches apply

* fixing lint error

* Removed unneeded function call

Somehow removing this call fixed the crashing when Vulkan header was removed

* added missing NL

* Fixed missing members in Vulkan header

also added zero clear for some structs

* Fixed wrong structure ID

* Fixed Vulkan header

More aligned with official header definition now

* buildvulkanAsSeperateFunction

* Vulkan on Windows Test

* temporarly comment out gate to run windows task

* use temporarly windows-latest for build

* Commenting out other presets to build vulkan

* reenable cpu

* commenting out error action stop

* temporarly commenting out rocm

* set vulkan path

* comment out cude for faster turnaround

* correct vulkan install

* correct vulkan silent install

* fixed install command

* revert debugging changes (vulkan builds on windows)

* revert windows-latest

* trying to build vulkan for linux

* temporarly disable cuda and rocm

* try again linux build

* fix version

* trying to fix

* trying again

* trying again

* fix version

* fixed vulkan-sdk name

* try again

* trying again

* try without version number

* try again

* add some more extra

* trying to use version 1.4.313

* revert debugging changes

* Filter out already supported gpus

* revert debug code

* Use runners for GPU discovery

This revamps how we discover GPUs in the system by leveraging the Ollama
runner.  This should eliminate inconsistency between our GPU discovery and the
runners capabilities at runtime, particularly for cases where we try to filter
out unsupported GPUs.  Now the runner does that implicitly based on the actual
device list.  In some cases free VRAM reporting can be unreliable which can
leaad to scheduling mistakes, so this also includes a patch to leverage more
reliable VRAM reporting libraries if available.

Automatic workarounds have been removed as only one GPU leveraged this, which
is now documented. This GPU will soon fall off the support matrix with the next
ROCm bump.

Additional cleanup of the scheduler and discovery packages can be done in the
future once we have switched on the new memory management code, and removed
support for the llama runner.

* timing info for runner

* WIP - wire up Vulkan with the new engine based discovery

Not a complete implementation - free VRAM is better, but not accurate on
windows

* fix - trust the library paths from discovery when starting runner

* fix index bug

* fix vulkan ids to be underlying

* fix - give bootstrapping more time on slow systems

* Test if Vulkan device is supported

* vk_check_flash_attention is not needed (coompat2 coopmapt and scalar implementation exist)

* Handle GGML_VK_VISIBLE_DEVICES

* ask for supported first

* win: fix CPU query buffer handling

Try in a short loop until we get the size right.

* test: harden integration tests for slow start

If the server takes a while to start up, block
tests from starting until it's online to avoid
setting large timeouts in individual test cases.

* gofumpt fix

* fix build

* merge fixes

* merge fixes

* fixed build

* merge fixes

* fixing build

* fixed build

* fixed formatting

* fixed build

* fix vulkan gpu id patch

* sync llama.cpp vulkan code

* update build windows script

* merge fixes

* fix format

* fixed vulkan casing

* handle igpu as gpu

* improve case

* print out unknown library

* rturn Vulkan for vulkan library

* Revert "rturn Vulkan for vulkan library"

This reverts commit 690461a12f.

* fixed patch number

* return Library Name

* remvoe debug code

* return integrated in vulkan backend

* Return pci Properties

* update patch

* directly get pci proeprties without parsing

* workaround for filtering devices. Correct way is to have a LibraryPosition Parameter in the deviceInfo

* Revert "directly get pci proeprties without parsing"

This reverts commit 8e0624851f.

* Set FilteredID for Environment Filtering

* ROCm Library is named ROCm

* revert changes in patch

* Create 0028-vulkan-pci-and-memory.patch

* vulkan memory patch

* casing fix

* Add more pci properties

* Added better memory management

* Added better memory managament

* fixed patch

* Fixed patch

* FilterID creation group by library

* filter out vulkan supported by other gpu

* fixing deviceid compare

* Vulkan Fix FA coopmat1 invalid array indexing

* Use everywhere the same Vulkan Version 1.4.321.1

* Remove unneeded patch

* vulkan update

* sync vulkan glsl files

* only use for vulkan the filteredid (numeric device number)

* simplify code

---------

Signed-off-by: Vadim Grinco <vadim@grinco.eu>
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: pufferffish <github@bandersnatch.anonaddy.com>
Co-authored-by: KOISHI KOMEIJI FROM TOUHOU 11 <fuck>
Co-authored-by: DSLstandard <qgeneral35@gmail.com>
Co-authored-by: pufferffish <me@windtfw.com>
Co-authored-by: yeongbba <yeongmo.lee@logpresso.com>
Co-authored-by: tomaThomas <tomathomas@mailbox.org>
Co-authored-by: Antoine Viallon <antoine@lesviallon.fr>
Co-authored-by: Vadim Grinco <vadim@grinco.eu>
Co-authored-by: zeo <108888572+zeozeozeo@users.noreply.github.com>
Co-authored-by: Louis Beaumont <louis.beaumont@gmail.com>
Co-authored-by: Daniel Hiltgen <dhiltgen@users.noreply.github.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Nikita <50599445+nasrally@users.noreply.github.com>
Co-authored-by: Masato Nakasaka <masato.nakasaka@intel.com>
Co-authored-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
2025-10-14 10:59:58 -07:00
Devon Rifkin
fd8aa947f3 Merge pull request #12562 from ollama/drifkin/registries
add registries for parsers/renderers
2025-10-14 02:01:53 -07:00
Devon Rifkin
ddaca643d0 add registries for parsers/renderers 2025-10-14 01:13:54 -07:00
Grace
05982a95cb Qwen3VL Cloud Parser and Renderer (#12526)
* working (other than tool call is the incorrect order) for tool calls and tools

* Tests work, other than image tags (tests do not go through server) and tools (not in the correct order, but contents are the same)

* testing for qwen3vl parser - toolparser is working

* made changes to JSON tool parser, wraps the TollCallFunction with a TollCall object

* Working parser for thinking models - assumes state of thinking, emits unambiguous content in thinking, does not call tool call in thinking

* changed the parser to start with collecting content

* thinking prefill

* add hasThinkingSupport parameter to parser

* qwen3-vl -> qwen3-vl-instruct for renderer/parser

* Add hasThinkingSupport=false to QwenVLParser

---------

Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2025-10-13 16:52:33 -07:00
187 changed files with 32832 additions and 114 deletions

View File

@@ -237,13 +237,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive
target: archive_novulkan
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive
target: archive_novulkan
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -299,12 +299,14 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -317,6 +319,14 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -334,6 +344,7 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.target }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest

View File

@@ -52,6 +52,12 @@ jobs:
container: rocm/dev-ubuntu-22.04:6.1.2
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan
container: ubuntu:22.04
extra-packages: >
mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make
runs-on: linux
container: ${{ matrix.container }}
steps:
@@ -59,7 +65,19 @@ jobs:
- run: |
[ -n "${{ matrix.container }}" ] || sudo=sudo
$sudo apt-get update
# Add LunarG Vulkan SDK apt repo for Ubuntu 22.04
if [ "${{ matrix.preset }}" = "Vulkan" ]; then
$sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg
# Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY
echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null
$sudo apt-get update
fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
fi
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/cache@v4
@@ -92,18 +110,21 @@ jobs:
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm'
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: matrix.preset == 'CUDA'
name: Install CUDA ${{ matrix.cuda-version }}
@@ -133,6 +154,18 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:

View File

@@ -139,3 +139,15 @@ if(CMAKE_HIP_COMPILER)
endforeach()
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()

View File

@@ -30,7 +30,7 @@
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_ARCHITECTURES": "50;52;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
@@ -70,6 +70,10 @@
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
},
{
"name": "Vulkan",
"inherits": [ "Default" ]
}
],
"buildPresets": [
@@ -122,6 +126,11 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 6"
},
{
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
}
]
}

View File

@@ -7,6 +7,7 @@ ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
ARG VULKANVERSION=1.4.321.1
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
@@ -17,6 +18,16 @@ RUN yum install -y yum-utils \
&& dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
@@ -106,6 +117,13 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
@@ -123,7 +141,8 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -136,14 +155,40 @@ FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM ubuntu:24.04 AS default
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive /lib/ollama /usr/lib/ollama

View File

@@ -544,6 +544,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) a multimodal(text/audio/image) chatbot.
### Mobile

View File

@@ -204,7 +204,7 @@ type ToolCall struct {
}
type ToolCallFunction struct {
Index int `json:"index,omitempty"`
Index int `json:"index"`
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
@@ -266,9 +266,9 @@ func (pt PropertyType) String() string {
type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type"`
Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"`
Description string `json:"description"`
Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"`
}
@@ -332,7 +332,7 @@ func (t *ToolFunctionParameters) String() string {
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Description string `json:"description,omitempty"`
Parameters ToolFunctionParameters `json:"parameters"`
}

View File

@@ -298,6 +298,30 @@ func TestToolFunction_UnmarshalJSON(t *testing.T) {
}
}
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",
Arguments: ToolCallFunctionArguments{"message": "hi"},
}
data, err := json.Marshal(fn)
require.NoError(t, err)
raw := map[string]any{}
require.NoError(t, json.Unmarshal(data, &raw))
require.Contains(t, raw, "index")
assert.Equal(t, float64(0), raw["index"])
fn.Index = 3
data, err = json.Marshal(fn)
require.NoError(t, err)
raw = map[string]any{}
require.NoError(t, json.Unmarshal(data, &raw))
require.Contains(t, raw, "index")
assert.Equal(t, float64(3), raw["index"])
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string

View File

@@ -473,34 +473,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
local := opts.Copy()
if local.MultiModal {
prompt, images, err := extractFileData(local.Prompt)
if err != nil {
return err
}
local.Prompt = prompt
local.Images = images
}
messages := slices.Clone(local.Messages)
if local.System != "" {
hasSystem := len(messages) > 0 && messages[0].Role == "system"
if !hasSystem {
messages = append([]api.Message{{Role: "system", Content: local.System}}, messages...)
}
}
if local.Prompt != "" || len(local.Images) > 0 {
messages = append(messages, api.Message{Role: "user", Content: local.Prompt, Images: local.Images})
}
local.Messages = messages
_, err = chat(cmd, local)
return err
return generate(cmd, opts)
}
func SigninHandler(cmd *cobra.Command, args []string) error {
@@ -1126,6 +1099,8 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return client.Pull(cmd.Context(), &request, fn)
}
type generateContextKey string
type runOptions struct {
Model string
ParentModel string
@@ -1388,6 +1363,137 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
var latest api.GenerateResponse
generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
if !ok {
generateContext = []int{}
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var thinkingContent strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
fn := func(response api.GenerateResponse) error {
latest = response
content := response.Response
if response.Response != "" || !opts.HideThinking {
p.StopAndClear()
}
if response.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(plainText))
thinkTagOpened = true
thinkTagClosed = false
}
thinkingContent.WriteString(response.Thinking)
displayResponse(response.Thinking, opts.WordWrap, state)
}
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) {
if !strings.HasSuffix(thinkingContent.String(), "\n") {
fmt.Println()
}
fmt.Print(thinkingOutputClosingText(plainText))
thinkTagOpened = false
thinkTagClosed = true
state = &displayResponseState{}
}
displayResponse(content, opts.WordWrap, state)
if response.ToolCalls != nil {
toolCalls := response.ToolCalls
if len(toolCalls) > 0 {
fmt.Print(renderToolCalls(toolCalls, plainText))
}
}
return nil
}
if opts.MultiModal {
opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
if err != nil {
return err
}
}
if opts.Format == "json" {
opts.Format = `"` + opts.Format + `"`
}
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Images: opts.Images,
Format: json.RawMessage(opts.Format),
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
Think: opts.Think,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
if opts.Prompt != "" {
fmt.Println()
fmt.Println()
}
if !latest.Done {
return nil
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return err
}
if verbose {
latest.Summary()
}
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
}
func RunServer(_ *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil {
return err

View File

@@ -70,6 +70,7 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
if dev.Library == "ROCm" && rocmDir != "" {
info.DependencyPath = append(info.DependencyPath, rocmDir)
}
// TODO any special processing of Vulkan devices?
resp = append(resp, info)
}
if len(resp) == 0 {
@@ -97,7 +98,16 @@ func (l GpuInfoList) GetVisibleDevicesEnv() []string {
if len(l) == 0 {
return nil
}
return []string{rocmGetVisibleDevicesEnv(l)}
res := []string{}
envVar := rocmGetVisibleDevicesEnv(l)
if envVar != "" {
res = append(res, envVar)
}
envVar = vkGetVisibleDevicesEnv(l)
if envVar != "" {
res = append(res, envVar)
}
return res
}
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
@@ -127,6 +137,25 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
return envVar + strings.Join(ids, ",")
}
func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "Vulkan" {
continue
}
if info.filterID != "" {
ids = append(ids, info.filterID)
} else {
ids = append(ids, info.ID)
}
}
if len(ids) == 0 {
return ""
}
envVar := "GGML_VK_VISIBLE_DEVICES="
return envVar + strings.Join(ids, ",")
}
// GetSystemInfo returns the last cached state of the GPUs on the system
func GetSystemInfo() SystemInfo {
deviceMu.Lock()

View File

@@ -86,6 +86,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
// are enumerated, but not actually supported.
// We run this in serial to avoid potentially initializing a GPU multiple
// times concurrently leading to memory contention
// TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs
for dir := range libDirs {
var dirs []string
if dir != "" {
@@ -131,19 +132,25 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
go func(i int) {
defer wg.Done()
var envVar string
id := devices[i].ID
if devices[i].Library == "ROCm" {
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
} else {
envVar = "ROCR_VISIBLE_DEVICES"
}
} else {
} else if devices[i].Library == "CUDA" {
envVar = "CUDA_VISIBLE_DEVICES"
} else if devices[i].Library == "Vulkan" {
id = devices[i].FilteredID
envVar = "GGML_VK_VISIBLE_DEVICES"
} else {
slog.Error("Unknown Library:" + devices[i].Library)
}
extraEnvs := []string{
"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
envVar + "=" + devices[i].ID, // Filter to just this one GPU
"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
envVar + "=" + id, // Filter to just this one GPU
}
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
needsDelete[i] = true
@@ -163,6 +170,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
wg.Wait()
logutil.Trace("supported GPU library combinations", "supported", supported)
filterOutVulkanThatAreSupportedByOtherGPU(needsDelete)
// Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible
filterOverlapByLibrary(supported, needsDelete)
@@ -184,7 +193,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
}
// Now filter out any overlap with different libraries (favor CUDA/ROCm over others)
// Now filter out any overlap with different libraries (favor CUDA/HIP over others)
for i := 0; i < len(devices); i++ {
for j := i + 1; j < len(devices); j++ {
// For this pass, we only drop exact duplicates
@@ -340,12 +349,40 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
}
// Apply any iGPU workarounds
iGPUWorkarounds(devices)
return devices
}
func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) {
// Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion
for i := range devices {
if devices[i].Library != "Vulkan" || needsDelete[i] {
continue
}
if devices[i].PCIID == "" {
continue
}
for j := range devices {
if i == j {
continue
}
if devices[j].PCIID == "" {
continue
}
if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] {
needsDelete[i] = true
slog.Debug("dropping Vulkan duplicate by PCI ID",
"vulkan_id", devices[i].ID,
"vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
"pci_id", devices[i].PCIID,
"kept_library", devices[j].Library,
"kept_id", devices[j].ID,
)
break
}
}
}
}
func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) {
// For multi-GPU systems, use the newest version that supports all the GPUs
for _, byLibDirs := range supported {
@@ -451,6 +488,7 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
}
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
pathNeeded := true
@@ -508,6 +546,7 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
}
}
logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
return devices
}
@@ -559,32 +598,3 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceIn
}
}
}
func iGPUWorkarounds(devices []ml.DeviceInfo) {
// short circuit if we have no iGPUs
anyiGPU := false
for i := range devices {
if devices[i].Integrated {
anyiGPU = true
break
}
}
if !anyiGPU {
return
}
memInfo, err := GetCPUMem()
if err != nil {
slog.Debug("failed to fetch system memory information for iGPU", "error", err)
return
}
for i := range devices {
if !devices[i].Integrated {
continue
}
// NVIDIA iGPUs return useless free VRAM data which ignores system buff/cache
if devices[i].Library == "CUDA" {
devices[i].FreeMemory = memInfo.FreeMemory
}
}
}

View File

@@ -37,7 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
UnreliableFreeMemory bool
// GPU information
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
ComputeMajor int `json:"compute_major"` // Compute Capability or gfx
ComputeMinor int `json:"compute_minor"`
@@ -175,7 +175,8 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
gpu.Library == "ROCm"
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
if !supportsFA {
return false

View File

@@ -24,6 +24,9 @@ func Host() *url.URL {
switch {
case !ok:
scheme, hostport = "http", s
if s == "ollama.com" {
scheme, hostport = "https", "ollama.com:443"
}
case scheme == "http":
defaultPort = "80"
case scheme == "https":
@@ -217,6 +220,7 @@ var (
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES")
VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES")
GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
)
@@ -307,6 +311,7 @@ func AsMap() map[string]EnvVar {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"}
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"}
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}

View File

@@ -37,6 +37,7 @@ func TestHost(t *testing.T) {
"https": {"https://1.2.3.4", "https://1.2.3.4:443"},
"https port": {"https://1.2.3.4:4321", "https://1.2.3.4:4321"},
"proxy path": {"https://example.com/ollama", "https://example.com:443/ollama"},
"ollama.com": {"ollama.com", "https://ollama.com:443"},
}
for name, tt := range cases {

View File

@@ -893,6 +893,7 @@ func (f GGML) SupportsFlashAttention() bool {
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"gemma3",
"gptoss", "gpt-oss",
"qwen3",
"qwen3moe",

View File

@@ -229,7 +229,7 @@ const (
TensorTypeMXFP4
)
// ParseFileType parses the provided GGUF file type
// ParseTensorType parses the provided GGUF tensor type
// Only Ollama supported types are considered valid
func ParseTensorType(s string) (TensorType, error) {
switch s {

File diff suppressed because one or more lines are too long

View File

@@ -267,10 +267,12 @@ static struct llama_model * llama_model_load_from_file_impl(
for (auto * dev : model->devices) {
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
size_t memory_free, memory_total;
ggml_backend_dev_memory(dev, &memory_free, &memory_total);
LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__,
ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
props.device_id ? props.device_id : "unknown id",
props.memory_free/1024/1024);
memory_free/1024/1024);
}
const int status = llama_model_load(path_model, splits, *model, params);

View File

@@ -69,7 +69,9 @@ func EnumerateGPUs() []ml.DeviceID {
for i := range C.ggml_backend_dev_count() {
device := C.ggml_backend_dev_get(i)
if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
switch C.ggml_backend_dev_type(device) {
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(device, &props)
ids = append(ids, ml.DeviceID{

View File

@@ -12,10 +12,11 @@ unused then it can be reset to free these data structures.
ggml/src/ggml-backend.cpp | 8 ++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 16 +++++++++++++++-
ggml/src/ggml-cuda/vendors/hip.h | 1 +
5 files changed, 29 insertions(+), 1 deletion(-)
src/llama.cpp | 4 +++-
6 files changed, 32 insertions(+), 2 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 1ff53ed0..ba181d09 100644
index 1ff53ed03..ba181d09d 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -178,6 +178,7 @@ extern "C" {
@@ -27,7 +28,7 @@ index 1ff53ed0..ba181d09 100644
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 3c3f22fc..43c91d9f 100644
index 3c3f22fc0..43c91d9f2 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -195,6 +195,10 @@ extern "C" {
@@ -42,7 +43,7 @@ index 3c3f22fc..43c91d9f 100644
struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 6ef5eeaf..0b757af5 100644
index 6ef5eeafa..0b757af59 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
@@ -61,7 +62,7 @@ index 6ef5eeaf..0b757af5 100644
GGML_ASSERT(device);
return device->iface.get_buffer_type(device);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 811462c7..87c6c34a 100644
index 811462c79..87c6c34a4 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -107,6 +107,11 @@ int ggml_cuda_get_device() {
@@ -109,7 +110,7 @@ index 811462c7..87c6c34a 100644
// backend reg
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 890c1036..1f06be80 100644
index 890c10364..1f06be80e 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -45,6 +45,7 @@
@@ -120,3 +121,21 @@ index 890c1036..1f06be80 100644
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
diff --git a/src/llama.cpp b/src/llama.cpp
index fe5a7a835..d821a96a0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -267,10 +267,12 @@ static struct llama_model * llama_model_load_from_file_impl(
for (auto * dev : model->devices) {
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
+ size_t memory_free, memory_total;
+ ggml_backend_dev_memory(dev, &memory_free, &memory_total);
LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__,
ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
props.device_id ? props.device_id : "unknown id",
- props.memory_free/1024/1024);
+ memory_free/1024/1024);
}
const int status = llama_model_load(path_model, splits, *model, params);

View File

@@ -22,7 +22,7 @@ diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index ba181d09..09ff75f9 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -169,6 +169,15 @@ extern "C" {
@@ -169,6 +169,17 @@ extern "C" {
const char * device_id;
// device capabilities
struct ggml_backend_dev_caps caps;
@@ -35,6 +35,8 @@ index ba181d09..09ff75f9 100644
+ int pci_device_id;
+ int pci_domain_id;
+ const char *library;
+ // number with which the devices are accessed (Vulkan)
+ const char *numeric_id;
};
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);

View File

@@ -0,0 +1,95 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye <xiaodong.ye@mthreads.com>
Date: Mon, 18 Aug 2025 12:48:07 +0800
Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++
1 file changed, 37 insertions(+)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 061cd078..adea7783 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
snprintf(description, description_size, "%s", props.deviceName.data());
}
+static std::string ggml_vk_get_device_id(int device) {
+ ggml_vk_instance_init();
+
+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ vk::PhysicalDeviceProperties2 props;
+ vk::PhysicalDeviceIDProperties deviceIDProps;
+ props.pNext = &deviceIDProps;
+ devices[device].getProperties2(&props);
+
+ const auto& uuid = deviceIDProps.deviceUUID;
+ char id[64];
+ snprintf(id, sizeof(id),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ uuid[0], uuid[1], uuid[2], uuid[3],
+ uuid[4], uuid[5],
+ uuid[6], uuid[7],
+ uuid[8], uuid[9],
+ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15]
+ );
+ return std::string(id);
+}
+
// backend interface
#define UNUSED GGML_UNUSED
@@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
ggml_vk_get_device_description(dev_idx, description, description_size);
}
+std::string ggml_backend_vk_get_device_id(int device) {
+ GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+ int dev_idx = vk_instance.device_indices[device];
+ return ggml_vk_get_device_id(dev_idx);
+}
+
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context {
std::string description;
bool is_integrated_gpu;
std::string pci_bus_id;
+ std::string id;
};
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
return ctx->description.c_str();
}
+static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ return ctx->id.c_str();
+}
+
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
ggml_backend_vk_get_device_memory(ctx->device, free, total);
@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev);
+ props->id = ggml_backend_vk_device_get_id(dev);
props->type = ggml_backend_vk_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+ ctx->id = ggml_backend_vk_get_device_id(i);
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,
--
2.51.0

View File

@@ -0,0 +1,254 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri Sep 5 08:25:03 2025 -0700
Subject: [PATCH] Vulkan PCI and Memory
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++-----
1 file changed, 145 insertions(+), 31 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index adea7783..fb7204ce 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) {
return ggml_vk_get_device_id(dev_idx);
}
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
- GGML_ASSERT(device < (int) vk_instance.device_indices.size());
- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+ size_t device;
+ std::string name;
+ std::string description;
+ bool is_integrated_gpu;
+ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function)
+ std::string pci_id;
+ std::string id;
+ std::string uuid;
+ int major;
+ int minor;
+ int driver_major;
+ int driver_minor;
+ int pci_bus_id;
+ int pci_device_id;
+ int pci_domain_id;
+};
+
+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) {
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size());
+
+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
- vk::PhysicalDeviceMemoryProperties2 memprops = {};
- bool membudget_supported = vk_instance.device_supports_membudget[device];
+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+ vk::PhysicalDeviceProperties2 props2;
+ vkdev.getProperties2(&props2);
- if (membudget_supported) {
- memprops.pNext = &budgetprops;
+ if (!ctx->is_integrated_gpu)
+ {
+ // Use vendor specific management libraries for best VRAM reporting if available
+ switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
+ ggml_hip_mgmt_release();
+ }
+ break;
+ case VK_VENDOR_ID_NVIDIA:
+ if (ggml_nvml_init() == 0) {
+ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+ ggml_nvml_release();
+ return;
+ }
+ ggml_nvml_release();
+ }
+ break;
+ }
}
- vkdev.getMemoryProperties2(&memprops);
+ // else fallback to memory budget if supported
- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
+ *total = 0;
+ *free = 0;
+ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
+ vk::PhysicalDeviceMemoryProperties2 memprops2;
+ memprops2.pNext = &mem_budget_props;
+ vkdev.getMemoryProperties2(&memprops2);
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ } else if (ctx->is_integrated_gpu) {
+ // Include shared memory on iGPUs
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ }
+ }
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *free += mem_budget_props.heapBudget[i];
+ } else if (ctx->is_integrated_gpu) {
+ *free += mem_budget_props.heapBudget[i];
+ }
+ }
+ if (*total > 0 && *free > 0) {
+ return;
+ } else if (*total > 0) {
+ *free = *total;
+ return;
+ }
+ // else just report the physical memory
+ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) {
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
*total = heap.size;
-
- if (membudget_supported && i < budgetprops.heapUsage.size()) {
- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
- } else {
- *free = heap.size;
- }
+ *free = heap.size;
break;
}
}
@@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
return std::string(pci_bus_id);
}
-//////////////////////////
-
-struct ggml_backend_vk_device_context {
- size_t device;
- std::string name;
- std::string description;
- bool is_integrated_gpu;
- std::string pci_bus_id;
- std::string id;
-};
+static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) {
+ if (id.empty()) return false;
+ unsigned int d = 0, b = 0, dev = 0, func = 0;
+ // Expected format: dddd:bb:dd.f (all hex)
+ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func);
+ if (n < 4) return false;
+ if (domain) *domain = (int) d;
+ if (bus) *bus = (int) b;
+ if (device) *device = (int) dev;
+ return true;
+}
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
- ggml_backend_vk_get_device_memory(ctx->device, free, total);
+ ggml_backend_vk_get_device_memory(ctx, free, total);
}
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->description = ggml_backend_vk_device_get_description(dev);
props->id = ggml_backend_vk_device_get_id(dev);
props->type = ggml_backend_vk_device_get_type(dev);
- props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
+ props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
@@ -12564,6 +12633,17 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
};
+
+ props->compute_major = ctx->major;
+ props->compute_minor = ctx->minor;
+ props->driver_major = ctx->driver_major;
+ props->driver_minor = ctx->driver_minor;
+ props->integrated = ctx->is_integrated_gpu;
+ props->pci_bus_id = ctx->pci_bus_id;
+ props->pci_device_id = ctx->pci_device_id;
+ props->pci_domain_id = ctx->pci_domain_id;
+ props->library = GGML_VK_NAME;
+ props->numeric_id = ctx->id.empty() ? nullptr : ctx->id.c_str();
}
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
+ std::vector<vk::PhysicalDevice> vk_devices = vk_instance.instance.enumeratePhysicalDevices();
+
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
char desc[256];
@@ -13000,13 +13081,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->name = GGML_VK_NAME + std::to_string(i);
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
- ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+ ctx->pci_id = ggml_backend_vk_get_device_pci_id(i);
ctx->id = ggml_backend_vk_get_device_id(i);
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,
/* .context = */ ctx,
});
+
+ // Gather additional information about the device
+ int dev_idx = vk_instance.device_indices[i];
+ vk::PhysicalDeviceProperties props1;
+ vk_devices[dev_idx].getProperties(&props1);
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceIDProperties device_id_props;
+ vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props;
+ vk::PhysicalDeviceDriverProperties driver_props;
+ props2.pNext = &device_id_props;
+ device_id_props.pNext = &pci_bus_props;
+ pci_bus_props.pNext = &driver_props;
+ vk_devices[dev_idx].getProperties2(&props2);
+ std::ostringstream oss;
+ oss << std::hex << std::setfill('0');
+ oss << "GPU-";
+ int byteIdx = 0;
+ for (int i = 0; i < 16; ++i, ++byteIdx) {
+ oss << std::setw(2) << static_cast<int>(device_id_props.deviceUUID[i]);
+ if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) {
+ oss << '-';
+ }
+ }
+ ctx->uuid = oss.str();
+ ctx->pci_bus_id = pci_bus_props.pciBus;
+ ctx->pci_device_id = pci_bus_props.pciDevice;
+ ctx->pci_domain_id = pci_bus_props.pciDomain;
+ ctx->id = std::to_string(i);
+ ctx->major = 0;
+ ctx->minor = 0;
+ // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string
+ ctx->driver_major = 0;
+ ctx->driver_minor = 0;
}
initialized = true;
}
--
2.51.0

View File

@@ -0,0 +1,137 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Santosh Bhavani <santosh.bhavani@live.com>
Date: Wed, 15 Oct 2025 09:29:51 -0700
Subject: [PATCH] NVML fallback for unified memory GPUs
---
ggml/src/mem_nvml.cpp | 71 +++++++++++++++++++++++++++++++++++++++++--
1 file changed, 68 insertions(+), 3 deletions(-)
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
index c9073cef..f473a2a2 100644
--- a/ggml/src/mem_nvml.cpp
+++ b/ggml/src/mem_nvml.cpp
@@ -13,6 +13,7 @@
#include <filesystem>
#include <mutex>
#include <array>
+#include <cstring>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
@@ -23,6 +24,8 @@
#else
# include <dlfcn.h>
# include <unistd.h>
+# include <fstream>
+# include <string>
#endif
namespace fs = std::filesystem;
@@ -79,12 +82,36 @@ struct {
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
+ nvmlReturn_t (*nvmlDeviceGetName)(nvmlDevice_t, char *, unsigned int);
const char * (*nvmlErrorString)(nvmlReturn_t result);
-} nvml { NULL, NULL, NULL, NULL, NULL };
+} nvml { NULL, NULL, NULL, NULL, NULL, NULL, NULL };
static std::mutex ggml_nvml_lock;
extern "C" {
+#ifndef _WIN32
+// Helper function to get available memory from /proc/meminfo on Linux
+// Returns MemAvailable as calculated by the kernel
+static size_t get_mem_available() {
+ std::ifstream meminfo("/proc/meminfo");
+ if (!meminfo.is_open()) {
+ return 0;
+ }
+
+ std::string line;
+ while (std::getline(meminfo, line)) {
+ if (line.find("MemAvailable:") == 0) {
+ size_t available_kb;
+ sscanf(line.c_str(), "MemAvailable: %zu kB", &available_kb);
+ // Convert from kB to bytes
+ return available_kb * 1024;
+ }
+ }
+
+ return 0;
+}
+#endif
+
int ggml_nvml_init() {
std::lock_guard<std::mutex> lock(ggml_nvml_lock);
if (nvml.handle != NULL) {
@@ -117,8 +144,9 @@ int ggml_nvml_init() {
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
+ nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetName");
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
- if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL || nvml.nvmlErrorString == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
@@ -151,8 +179,9 @@ int ggml_nvml_init() {
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
+ nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) dlsym(nvml.handle, "nvmlDeviceGetName");
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
- if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
dlclose(nvml.handle);
nvml.handle = NULL;
@@ -199,10 +228,46 @@ int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
}
nvmlMemory_t memInfo = {0};
status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo);
+
if (status == NVML_SUCCESS) {
+ // NVML working correctly, use its values
*free = memInfo.free;
*total = memInfo.total;
+ return NVML_SUCCESS;
}
+
+#ifndef _WIN32
+ // Handle NVML_ERROR_NOT_SUPPORTED - this indicates NVML doesn't support
+ // reporting framebuffer memory (e.g., unified memory GPUs where FB memory is 0)
+ if (status == NVML_ERROR_NOT_SUPPORTED) {
+ // Use system memory from /proc/meminfo
+ size_t mem_available = get_mem_available();
+ size_t mem_total = 0;
+
+ // Read MemTotal
+ std::ifstream meminfo("/proc/meminfo");
+ if (meminfo.is_open()) {
+ std::string line;
+ while (std::getline(meminfo, line)) {
+ if (line.find("MemTotal:") == 0) {
+ size_t total_kb;
+ sscanf(line.c_str(), "MemTotal: %zu kB", &total_kb);
+ mem_total = total_kb * 1024;
+ break;
+ }
+ }
+ }
+
+ if (mem_total > 0) {
+ *total = mem_total;
+ *free = mem_available;
+ GGML_LOG_INFO("%s NVML not supported for memory query, using system memory (total=%zu, available=%zu)\n",
+ __func__, mem_total, mem_available);
+ return NVML_SUCCESS;
+ }
+ }
+#endif
+
return status;
}

View File

@@ -0,0 +1,49 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Julius Tischbein <ju.tischbein@gmail.com>
Date: Wed, 15 Oct 2025 13:54:15 +0200
Subject: [PATCH] CUDA: Changing the CUDA scheduling strategy to spin (#16585)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* CUDA set scheduling strategy to spinning for cc121
* Using prop.major and prop.minor, include HIP and MUSA
* Exclude HIP and MUSA
* Remove trailing whitespace
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* Remove empty line
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---
ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6a278b5e9..87941f872 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -340,6 +340,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
+
+ // Temporary performance fix:
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
+ // TODO: Check for future drivers the default scheduling strategy and
+ // remove this call again when cudaDeviceScheduleSpin is default.
+ if (prop.major == 12 && prop.minor == 1) {
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
+ }
+
#endif // defined(GGML_USE_HIP)
}

View File

@@ -567,6 +567,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
(gpus[0].Library == "cpu" && s.options.UseMMap == nil) ||
(gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
}
@@ -927,7 +928,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
}
}
libraryGpuLayers := assignLayers(layers, gl, s.options.NumGPU, lastUsedGPU)
libraryGpuLayers := assignLayers(layers, gl, requireFull, s.options.NumGPU, lastUsedGPU)
if libraryGpuLayers.Sum() > gpuLayers.Sum() {
gpuLayers = libraryGpuLayers
}
@@ -993,7 +994,7 @@ nextLayer:
}
// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
func assignLayers(layers []uint64, gpus discover.GpuInfoList, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
func assignLayers(layers []uint64, gpus discover.GpuInfoList, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
// If we can't fit everything then prefer offloading layers other than the output layer
for range 2 {
// requestedLayers may be -1 if nothing was requested
@@ -1002,14 +1003,14 @@ func assignLayers(layers []uint64, gpus discover.GpuInfoList, requestedLayers in
if !envconfig.SchedSpread() {
for i := lastUsedGPU; i < len(gpus); i++ {
// Try to pack things into as few GPUs as possible
forceRequest := i == len(gpus)-1
forceRequest := i == len(gpus)-1 && !requireFull
gpuLayers = findBestFit(layers, gpus[:i+1], requestedLayers, forceRequest)
if gpuLayers.Sum() == len(layers) || gpuLayers.Sum() == requestedLayers {
break
}
}
} else {
gpuLayers = findBestFit(layers, gpus, requestedLayers, true)
gpuLayers = findBestFit(layers, gpus, requestedLayers, !requireFull)
}
// We only stop if we've gotten all of the layers - even if we got requestedLayers, we still

View File

@@ -127,6 +127,14 @@ func TestLLMServerFitGPU(t *testing.T) {
requireFull: true,
expectedErr: ErrLoadRequiredFull,
},
{
name: "requireFull numGPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 4,
requireFull: true,
expectedErr: ErrLoadRequiredFull,
},
}
for _, tt := range tests {

View File

@@ -57,7 +57,8 @@ var initDevices = sync.OnceFunc(func() {
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
gpus = append(gpus, d)
}
@@ -470,7 +471,9 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
// Mimic llama runner logs summarizing layers and memory
gpuLayers := 0
for layer := range maps.Values(b.layers) {
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
switch C.ggml_backend_dev_type(layer.d) {
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
gpuLayers++
}
}
@@ -479,7 +482,8 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
switch C.ggml_backend_dev_type(b.output) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
slog.Info("offloading output layer to CPU")
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
slog.Info("offloading output layer to GPU")
gpuLayers++
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
@@ -722,6 +726,9 @@ func (b *Backend) BackendDevices() []ml.DeviceInfo {
}
info.PCIID = fmt.Sprintf("%02x:%02x.%x", props.pci_bus_id, props.pci_device_id, props.pci_domain_id)
info.LibraryPath = ggml.LibPaths()
if props.numeric_id != nil {
info.FilteredID = C.GoString(props.numeric_id)
}
C.ggml_backend_dev_memory(dev, &props.memory_free, &props.memory_total)
info.TotalMemory = (uint64)(props.memory_total)

View File

@@ -20,10 +20,14 @@ include /src/ggml-cuda/vendors/
include /src/ggml-cuda/template-instances/
include /src/ggml-hip/
include /src/ggml-metal/
include src/ggml-vulkan/
include src/ggml-vulkan/vulkan-shaders
include CMakeLists.txt
include *.[chm]
include *.cpp
include *.cu
include *.cuh
include *.metal
include *.comp
include *.glsl
hide *

View File

@@ -178,6 +178,8 @@ extern "C" {
int pci_device_id;
int pci_domain_id;
const char *library;
// number with which the devices are accessed (Vulkan)
const char *numeric_id;
};
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);

View File

@@ -340,6 +340,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
// Temporary performance fix:
// Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
// TODO: Check for future drivers the default scheduling strategy and
// remove this call again when cudaDeviceScheduleSpin is default.
if (prop.major == 12 && prop.minor == 1) {
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
}
#endif // defined(GGML_USE_HIP)
}

View File

@@ -0,0 +1,211 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED)
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
else()
find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
endif()
set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE)
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
endfunction()
# Function to test shader extension support
# Parameters:
# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product")
# TEST_SHADER_FILE - Path to the test shader file
# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result
function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)
execute_process(
COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error
)
if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*")
message(STATUS "${EXTENSION_NAME} not supported by glslc")
set(${RESULT_VARIABLE} OFF PARENT_SCOPE)
else()
message(STATUS "${EXTENSION_NAME} supported by glslc")
set(${RESULT_VARIABLE} ON PARENT_SCOPE)
add_compile_definitions(${RESULT_VARIABLE})
# Ensure the extension support is forwarded to vulkan-shaders-gen
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE)
endif()
endfunction()
if (Vulkan_FOUND)
message(STATUS "Vulkan found")
ggml_add_backend_library(ggml-vulkan
ggml-vulkan.cpp
../../include/ggml-vulkan.h
)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
# Test all shader extensions
test_shader_extension_support(
"GL_KHR_cooperative_matrix"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp"
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_NV_cooperative_matrix2"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp"
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_integer_dot_product"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_bfloat16"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp"
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
)
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
# Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
# Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
endif()
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
endif()
if (GGML_VULKAN_DEBUG)
add_compile_definitions(GGML_VULKAN_DEBUG)
endif()
if (GGML_VULKAN_MEMORY_DEBUG)
add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
endif()
if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON)
endif()
if (GGML_VULKAN_VALIDATE)
add_compile_definitions(GGML_VULKAN_VALIDATE)
endif()
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
if (CMAKE_CROSSCOMPILING)
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
else()
detect_host_compiler()
if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
message(FATAL_ERROR "Host compiler not found")
else()
message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
endif()
else()
# For non-cross-compiling, use empty toolchain (use host compiler)
set(HOST_CMAKE_TOOLCHAIN_FILE "")
endif()
include(ExternalProject)
if (CMAKE_CROSSCOMPILING)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
endif()
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
-DCMAKE_INSTALL_BINDIR=.
-DCMAKE_BUILD_TYPE=$<CONFIG>
${VULKAN_SHADER_GEN_CMAKE_ARGS}
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
BUILD_ALWAYS TRUE
# NOTE: When DESTDIR is set using Makefile generators and
# "make install" triggers the build step, vulkan-shaders-gen
# would be installed into the DESTDIR prefix, so it is unset
# to ensure that does not happen.
INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
${CMAKE_COMMAND} --install . --config $<CONFIG>
)
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
# Because external projects do not provide source-level tracking,
# the vulkan-shaders-gen sources need to be explicitly added to
# ensure that changes will cascade into shader re-generation.
file(GLOB _ggml_vk_shaders_gen_sources
CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp"
"${_ggml_vk_input_dir}/*.h")
add_custom_command(
OUTPUT ${_ggml_vk_header}
COMMAND ${_ggml_vk_genshaders_cmd}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
DEPENDS ${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders header"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})
foreach (file_full ${_ggml_vk_shader_files})
get_filename_component(file ${file_full} NAME)
set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp")
add_custom_command(
OUTPUT ${_ggml_vk_target_cpp}
DEPFILE ${_ggml_vk_target_cpp}.d
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
--source ${file_full}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
--target-cpp ${_ggml_vk_target_cpp}
DEPENDS ${file_full}
${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders for ${file}"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp})
endforeach()
else()
message(WARNING "Vulkan not found")
endif()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,31 @@
cmake_minimum_required(VERSION 3.19)
project("vulkan-shaders-gen" C CXX)
find_package (Threads REQUIRED)
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat glslc support")
endif()
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat2 glslc support")
endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
message(STATUS "Enabling dot glslc support")
endif()
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
message(STATUS "Enabling bfloat16 glslc support")
endif()
if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
message(STATUS "Enabling shader debug info")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)

View File

@@ -0,0 +1,29 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
}
}

View File

@@ -0,0 +1,69 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif
void main() {
uint idx = get_idx();
uint orig_idx = idx;
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
FLOAT_TYPE sum_sq = 0;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
sum_sq += sum*sum;
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
idx += num_threads;
}
#if ADD_RMS
if (p.param3 != 0) {
// reduce the sum within each subgroup, then across subgroups
const uint NumSubgroups = num_threads / gl_SubgroupSize;
sum_sq = subgroupAdd(sum_sq);
if (gl_SubgroupInvocationID == 0) {
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
sum_sq += sumsh[gl_SubgroupID + s];
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
}
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
}
}
#endif
}

View File

@@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint ne0;
uint ne1;
uint s01;
uint s02;
uint s11;
uint s21;
} p;
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {int32_t data_c[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i1 = gl_WorkGroupID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i11 = data_c[i1 + i2 * p.s21];
const uint s1 = p.ne0;
const uint s2 = p.ne0 * p.ne1;
const uint d0 = i1 * s1 + i2 * s2;
const uint a0 = i1 * p.s01 + i2 * p.s02;
const uint b0 = i11 * p.s11;
for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
}
}

View File

@@ -0,0 +1,60 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define FLT_MAX 3.402823466e+38F
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
if (row >= p.KY) {
return;
}
A_TYPE amax = -FLT_MAX;
uint acol = col;
if (col < p.KX) {
amax = data_a[row*p.KX + col];
}
for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
A_TYPE val = data_a[row*p.KX + i];
if (val > amax) {
amax = val;
acol = i;
}
}
tmp[col] = acol;
tmpmax[col] = amax;
barrier();
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
if (col < s && col + s < p.KX) {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
}
}
barrier();
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
}
}

View File

@@ -0,0 +1,79 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
#define ASC 0
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint order;
} p;
shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];
void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}
void argsort(bool needs_bounds_check) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
// initialize indices
dst_row[col] = col;
a_sh[col] = data_a[row_offset + col];
barrier();
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;
int sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
swap(idx_0, idx_1);
}
barrier();
}
}
if (col < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col];
} else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
}
}
}
void main() {
if (p.ncols == BLOCK_SIZE) {
argsort(false);
} else {
argsort(true);
}
}

View File

@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}

View File

@@ -0,0 +1,41 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const int dim = p.param3;
if (idx >= p.ne) {
return;
}
const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
const uint i2_offset = i2*p.ne21*p.ne20;
const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
uint o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
#else
if (is_src0) {
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
} else {
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
}
#endif
}

View File

@@ -0,0 +1,49 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : require
const uint num_threads = 128;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 4;
// fast path for when all four iterations are in-bounds
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
} else {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
}
}

View File

@@ -0,0 +1,105 @@
#version 450
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint ne;
uint batches;
uint channels;
uint dst_w;
uint dst_h;
uint src_w;
uint src_h;
uint knl_w;
uint knl_h;
int stride_x;
int stride_y;
int pad_x;
int pad_y;
int dilation_x;
int dilation_y;
} p;
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
uint i0 = idx / p.dst_w;
uint dst_x = idx - i0 * p.dst_w;
uint i1 = i0 / p.dst_h;
uint dst_y = i0 - i1 * p.dst_h;
uint n = i1 / p.channels;
uint c = i1 - n * p.channels;
uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
uint knl_i = c * p.knl_h * p.knl_w;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
sum = fma(v, k, sum);
}
}
return sum;
}
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
uint i0 = idx / p.channels;
uint c = idx - i0 * p.channels;
uint i1 = i0 / p.dst_w;
uint dst_x = i0 - i1 * p.dst_w;
uint n = i1 / p.dst_h;
uint dst_y = i1 - n * p.dst_h;
uint src_i = n * p.channels * p.src_h * p.src_w;
uint src_row = p.src_w * p.channels;
uint knl_row = p.knl_w * p.channels;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
sum = fma(v, k, sum);
}
}
return sum;
}
void main() {
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
FLOAT_TYPE result =
#ifdef WHCN
conv_2d_dw_whcn(idx);
#else
conv_2d_dw_cwhn(idx);
#endif
dst_data[idx] = D_TYPE(result);
}

View File

@@ -0,0 +1,349 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [W, H, Cin, N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, Cout, N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t Cout;
uint32_t Cin;
uint32_t N;
// Tensor spatial sizes: kernel, input, output
uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
// Parameters: stride, padding, dilation - 0=y, 1=x
uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
#ifdef TRANSPOSE
uint32_t s0mp; uint32_t s0L;
uint32_t s1mp; uint32_t s1L;
#endif
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
layout(constant_id = 6) const uint SHMEM_PAD = 4;
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;
uint32_t n_elems_out = K * NPQ;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#ifdef COOPMAT2
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=Cout
C=Cin
R,S=KH,KW
P,Q=OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
uint32_t Cin_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
#ifdef USE_COLLECTIVES
uint32_t cached_CRS_idx;
uint32_t cached_Cin_idx;
uint32_t cached_KH_idx;
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
/* Load kernel to A_block: (BS_K x BS_CRS)*/
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
#ifdef TRANSPOSE
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
#else
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
#endif
float val = knl_data[knl_idx];
if (K_idx >= K || CRS_idx_a >= CRS) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
uint32_t CRS_idx_b;
uint32_t Cin_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
if (use_collectives == 1) {
CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif
#ifdef TRANSPOSE
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
#else
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
#endif
uint32_t src_idx =
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
float val = src_data[src_idx];
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
#ifdef TRANSPOSE
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
#endif
) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#ifdef COOPMAT2
coopMatPerElementNV(matC, matC, perElemOpStore);
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
}
}
}
}
#endif
}

View File

@@ -0,0 +1,98 @@
#version 450
#include "types.glsl"
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
layout (push_constant) uniform parameter {
uint32_t Cout;
uint32_t Cin;
uint32_t K;
uint32_t L;
uint32_t KL;
uint32_t nb01;
uint32_t nb02;
uint32_t nb11;
uint32_t nb1;
int32_t s0;
} p;
uint32_t Cout_idx = gl_WorkGroupID.x;
const uint32_t bs = gl_WorkGroupSize.x;
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
shared D_TYPE tmp[4096];
uint splitWork(uint workSize){
return (bs + workSize -1) / bs;
}
void main(){
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx < tmp_len){
tmp[idx] = 0.0;
}
}
uint32_t L_blocks = splitWork(p.L);
for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
if(L_block_id > 0){
barrier();
// Shift values in tmp to the current processing window
for(int i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx >= bs*p.s0 && idx < tmp_len){
tmp[idx-bs*p.s0] = tmp[idx];
tmp[idx] = 0.0;
}else if(idx >= p.K && idx < bs*p.s0){
tmp[idx] = 0.0;
}
}
}
barrier();
// Save contributions of the block to tmp
uint32_t L_idx = L_block_id*bs + tid;
for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
D_TYPE dp = 0.0;
for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
if(L_idx < p.L){
B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
dp = fma(elemKrn, elemInp, dp);
}
}
tmp[tid*p.s0 + K_idx] += dp;
barrier();
}
// Save the computed values except the last block that can have different size
uint32_t KLb_idx = L_block_id*bs*p.s0;
if(L_block_id < L_blocks-1){
for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
uint32_t sh_idx = p.s0*tid+s0_idx;
uint32_t KL_idx = KLb_idx+sh_idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
}
}
}
}
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
}
}
}

View File

@@ -0,0 +1,23 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
#endif
}

View File

@@ -0,0 +1,51 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
#include "dequant_funcs.glsl"
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = get_doffset() + dst_idx(idx);
uint src_idx = src0_idx_quant(idx, QUANT_K);
const uint a_offset = 0;
const uint ib = src_idx;
const vec2 dm = get_dm(ib, a_offset);
[[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
vec4 v = dequantize4(ib, j / QUANT_R, a_offset);
v = v * dm.x + vec4(dm.y);
#if QUANT_R == 2
data_d[dst_idx + j/2 + 0] = v[0];
data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];
data_d[dst_idx + j/2 + 1] = v[2];
data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];
#else
data_d[dst_idx + j + 0] = v[0];
data_d[dst_idx + j + 1] = v[1];
data_d[dst_idx + j + 2] = v[2];
data_d[dst_idx + j + 3] = v[3];
#endif
}
}

View File

@@ -0,0 +1,296 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#if defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 512;
#else
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 32;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
#if defined(SET_ROWS)
#include "generic_binary_head.glsl"
layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#if B_SIZE == 64
#define DATA_I_SWIZZLE .x
#else
#define DATA_I_SWIZZLE
#endif
#else
#include "generic_unary_head.glsl"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#endif
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
const uint xi0 = min(15, int(x0 + 8.5));
const uint xi1 = min(15, int(x1 + 8.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q4_1)
void quantize(uint dst_idx, uint src_idx)
{
float vmin = 1.0/0.0;
float vmax = -vmin;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
const float v = data_s[src_idx + j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(vmin);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
const uint xi0 = min(15, int(x0 + 0.5));
const uint xi1 = min(15, int(x1 + 0.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q5_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
const uint xi0 = min(31, int(x0 + 16.5));
const uint xi1 = min(31, int(x1 + 16.5));
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
}
data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
}
#endif
#if defined(DATA_A_Q5_1)
void quantize(uint dst_idx, uint src_idx)
{
float min = data_s[src_idx + 0];
float max = min;
[[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
const float v = data_s[src_idx + j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = (d != 0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(min);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - min)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
const uint xi0 = uint(x0 + 0.5);
const uint xi1 = uint(x1 + 0.5);
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
}
data_q[dst_idx].qh = qh;
}
#endif
#if defined(DATA_A_Q8_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0; // absolute max
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
const float v = data_s[src_idx + j];
amax = max(amax, abs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
const float x0 = data_s[src_idx + j]*id;
data_q[dst_idx].qs[j] = int8_t(round(x0));
}
}
#endif
#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;
if (x >= kvalues_iq4nl[15]) return 15;
int ml = 0, mu = 15;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
}
return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
}
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = (d != 0.0) ? 1.0/d : 0.0;
float sumqx = 0, sumq2 = 0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
const uint xi0 = best_index(x0);
const uint xi1 = best_index(x1);
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
}
#endif
#if defined(DATA_A_BF16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
}
#endif
#if defined(SET_ROWS)
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
uint i12 = fastmod(i03, p.ne12);
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
quantize(dst_idx, src0_idx);
}
#else
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = dst_idx_quant(idx, QUANT_K);
uint src_idx = get_aoffset() + src0_idx(idx);
quantize(dst_idx, src_idx);
}
#endif

View File

@@ -0,0 +1,17 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}

View File

@@ -0,0 +1,31 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
#include "generic_head.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
const uint CHUNK_SIZE = 512;
void main() {
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
const uint col = gl_LocalInvocationID.x;
uint count = 0;
[[unroll]]
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
const uint idx = base + i + col;
if (idx >= p.KX) {
break;
}
count += uint(data_a[idx] == data_b[idx]);
}
atomicAdd(data_d[0], D_TYPE(count));
}

View File

@@ -0,0 +1,20 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_GlobalInvocationID.x * 16;
if (i >= p.nel) {
return;
}
[[unroll]] for (uint l = 0; l < 16; l++) {
data_b[i + l] = D_TYPE(data_a[i + l]);
}
}

View File

@@ -0,0 +1,616 @@
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
#endif
#if defined(DATA_A_F16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
#endif
#if defined(DATA_A_BF16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
}
#endif
#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
}
#endif
#if defined(DATA_A_Q4_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
}
#endif
#if defined(DATA_A_Q5_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f);
}
#endif
#if defined(DATA_A_Q5_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = data_a[a_offset + ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y);
}
#endif
#if defined(DATA_A_Q8_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_IQ1_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const int i8 = int(iqs % 8);
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qs = data_a[a_offset + ib].qs[ib8];
const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
// Signed bitfield extract.
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
return dl * (vec2(gvec) + delta);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const int i8 = int(iqs % 8);
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qs = data_a[a_offset + ib].qs[ib8];
const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
// Signed bitfield extract.
const ivec4 gvec = ivec4(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2),
bitfieldExtract(grid, 2 * (i8 + 2), 2),
bitfieldExtract(grid, 2 * (i8 + 3), 2)
);
return dl * (vec4(gvec) + delta);
}
#endif
#if defined(DATA_A_IQ1_M)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib8 = iqs / 8;
const uint ib16 = iqs / 16;
const int i8 = int(iqs % 8);
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Signed bitfield extract.
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
return dl * (vec2(gvec) + delta);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib8 = iqs / 8;
const uint ib16 = iqs / 16;
const int i8 = int(iqs % 8);
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Signed bitfield extract.
const ivec4 gvec = ivec4(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2),
bitfieldExtract(grid, 2 * (i8 + 2), 2),
bitfieldExtract(grid, 2 * (i8 + 3), 2)
);
return dl * (vec4(gvec) + delta);
}
#endif
#if defined(DATA_A_IQ2_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
return db * vec2(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
return db * vec4(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ4_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec4 qs = u8vec4(
data_a[a_offset + ib].qs[iq + 0],
data_a[a_offset + ib].qs[iq + 1],
data_a[a_offset + ib].qs[iq + 2],
data_a[a_offset + ib].qs[iq + 3]
);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec4(
kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
}
#endif
#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
vec2 v1 = dequantize(ib, iqs + 1, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}
#endif
#if defined(DATA_A_IQ1_M)
vec2 get_dm(uint ib, uint a_offset) {
const uint16_t[4] scales = data_a[a_offset + ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
return vec2(d, 0);
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
#if defined(DATA_A_Q2_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q3_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 64; // 0,1
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const uint hmi = (iqs % 16) * 2; // 0,2,4..30
const uint j = (iqs % 64) / 4; // 0..3
const uint is = iqs / 8; // 0..15
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[a_offset + ib].d) * float(us - 32);
return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q4_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m),
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q5_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q6_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 64; // 0,1
const uint b = (iqs % 64) / 32; // 0,1
const uint is_b = (iqs % 16) / 8; // 0,1
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uint is = 8 * n + qhshift + is_b; // 0..15
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]);
return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif

View File

@@ -0,0 +1,720 @@
#include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
qs >>= shift;
qs &= 0x0F0F;
qs = unpack8(qs)[idx & 1];
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
block_q4_1 block;
};
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs) * d + m;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
block_q5_0 block;
};
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
block_q5_1 block;
};
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint uint_qh = bl.block.qh;
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs | qh) * d + m;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
block_q8_0_packed16 block;
};
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;
// Load 16b and select the byte for this element
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
float16_t ret = float16_t(qs) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
block_q2_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
block_q2_K_packed16 block;
};
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qs = (qs >> qsshift) & 0x0303;
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
block_q3_K block;
};
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;
const uint n = iqs / 128; // 0,1
const uint qsi = n * 32 + (iqs % 32); // 0..63
const uint hmi = (iqs % 32); // 0..31
const uint j = (iqs % 128) / 8; // 0..15
const uint is = iqs / 16; // 0..15
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
uint32_t scaleidx0 = (is < 8) ? is : (is-8);
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
uint32_t scaleidx1 = is + 8 - (is/4)*4;
uint32_t scaleidx1shift = (is/4)*2;
const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
const float16_t dl = bl.block.d * float16_t(us - 32);
float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
block_q4_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
block_q4_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};
#if defined(IS_MUL_MM2)
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
// into shared memory and then process the whole tile using those scales.
// There is a fetch function that loads into private variables and then a store
// function that stores into shared memory.
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
// the part that fetches from the structure (which has a different block layout).
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
shared vec2 shAscales[8 * shAscales_stride];
uvec4 row_v;
#endif
#if defined(DATA_A_Q4_K)
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q4_k_packed128[block_index].q4k[0];
}
}
#endif
#if defined(DATA_A_Q5_K)
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q5_k_packed128[block_index].q5k[0];
}
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
void store_scalesQ4_K(uint tid)
{
barrier();
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
uint is = idx + is_start;
uvec4 v = row_v;
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
}
barrier();
}
#endif
#endif
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q4k[0];
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
#endif
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float ret = d * float(qs) - m;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
block_q5_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
block_q5_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
block_q5_K_packed128 block;
};
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
#endif
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = ((qh >> is) & 0x101) << 4;
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs | qh)[idx & 1];
float ret = d * float(qs) - m;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
block_q6_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
block_q6_K_packed16 block;
};
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x40) >> 6; // 0,1
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
const uint is = (idx & 0xF0) >> 4; // 0..15
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
ql = (ql >> (b * 4)) & 0x0F0F;
uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qh = ((qh >> qhshift) & 0x0303) << 4;
int q = unpack8(ql | qh)[idx & 1];
float16_t ret = dscale * float16_t(q - 32);
return ret;
}
#if defined(DATA_A_IQ1_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
block_iq1_s block;
};
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5;
const uint ib8 = (idx & 0xF8) >> 3;
const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
return ret;
}
#endif
#if defined(DATA_A_IQ1_M)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
block_iq1_m block;
};
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
block_iq1_m_packed64 block;
};
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const uint idx = coordInBlock[1];
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
const uint ib8 = (idx & 0xF8) >> 3;
const uint ib16 = (idx & 0xF0) >> 4;
const int i8 = int(idx % 8);
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
return ret;
}
#endif
#if defined(DATA_A_IQ2_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
block_iq2_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
block_iq2_xxs_packed16 block;
};
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0x18) >> 3; // 0..3
const uint iqs = 8 * ib32 + ib8;
const uint qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
sign |= bitCount(sign) << 7;
uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
#if defined(DATA_A_IQ2_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
block_iq2_xs block;
};
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint is = (idx & 0xE0) >> 5; // 0..8
const uint sshift = (idx & 0x10) >> 2; // 0,4
const uint iqs = (idx & 0xF8) >> 3; // 0..63
const uint16_t qs = bl.block.qs[iqs];
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
uint sign = uint(qs >> 9);
sign |= bitCount(sign) << 7;
uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
#if defined(DATA_A_IQ2_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
block_iq2_s block;
};
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0xF8) >> 3; // 0..31
const uint qhshift = 2 * (ib8 % 4);
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib32];
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
const float d = float(bl.block.d);
const float db = d * 0.25 * (0.5 + scale);
const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ3_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
block_iq3_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
block_iq3_xxs_packed16 block;
};
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
uint idx = coordInBlock[1];
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint signs = pack32(u16vec2(
bl16.block.qs[is/2+0],
bl16.block.qs[is/2+1]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ3_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
block_iq3_s block;
};
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint iqh = (idx & 0xE0) >> 5;
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint qh = bl.block.qh[iqh];
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
const uint scale = bl.block.scales[iqs / 16];
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ4_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
block_iq4_xs block;
};
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 16) >> 2;
const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
return ret;
}
#endif
#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
block_iq4_nl block;
};
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
return ret;
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
return ret;
}
#endif
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)
#define dequantFuncA dequantFuncIQ1_S
#elif defined(DATA_A_IQ1_M)
#define dequantFuncA dequantFuncIQ1_M
#elif defined(DATA_A_IQ2_XXS)
#define dequantFuncA dequantFuncIQ2_XXS
#elif defined(DATA_A_IQ2_XS)
#define dequantFuncA dequantFuncIQ2_XS
#elif defined(DATA_A_IQ2_S)
#define dequantFuncA dequantFuncIQ2_S
#elif defined(DATA_A_IQ3_XXS)
#define dequantFuncA dequantFuncIQ3_XXS
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#elif defined(DATA_A_IQ4_XS)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#endif

View File

@@ -0,0 +1,13 @@
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint M;
uint K;
uint stride_a;
uint stride_b;
uint nel;
} p;
#include "types.glsl"

View File

@@ -0,0 +1,42 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint ib64 = ib32 / 2;
const uint b_idx = 256 * ib + 32 * ib32;
const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = data_a[ib].scales[ib64];
[[unroll]] for (int l = 0; l < 4; ++l) {
const uint ib16 = 2 * ib32 + l / 2;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
const uint qs = data_a[ib].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
[[unroll]] for (int j = 0; j < 8; ++j) {
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
}
}
}

View File

@@ -0,0 +1,35 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
uint qh = data_a[ib].qh[ib32];
const float d = float(data_a[ib].d);
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qs = data_a[ib].qs[4 * ib32 + l];
const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
[[unroll]] for (int j = 0; j < 8; ++j) {
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
}
}
}

View File

@@ -0,0 +1,44 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
uint qh = data_a[ib].qh[ib32];
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint qs = data_a[ib].qs[4 * ib32 + l];
const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
qs |= (qh << (8 - 2 * l)) & 0x300;
const uvec2 grid = iq2s_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -0,0 +1,43 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint16_t qs = data_a[ib].qs[4 * ib32 + l];
const uint sign7 = qs >> 9;
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uvec2 grid = iq2xs_grid[qs & 511];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -0,0 +1,49 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[8*is + 4],
data_a[ib].qs[8*is + 5],
data_a[ib].qs[8*is + 6],
data_a[ib].qs[8*is + 7]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uint qs = data_a[ib].qs[8 * is + l];
const uvec2 grid = iq2xxs_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -0,0 +1,40 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale nibble.
// Each block contains 4 scale bytes (8 scales) for 256 output values.
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));
// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
uint qh = data_a[ib].qh[is];
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = 8 * is + l;
const uint qs = data_a[ib].qs[iqs];
const uint gidx = qs | ((qh << (8 - l)) & 256);
const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));
const u8vec4 grid = unpack8(iq3s_grid[gidx]);
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -0,0 +1,51 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// 8 threads handle 1 superblock
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const uint s_idx = QUANT_K / 4 + 4 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[s_idx + 0],
data_a[ib].qs[s_idx + 1],
data_a[ib].qs[s_idx + 2],
data_a[ib].qs[s_idx + 3]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.5;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
// Restore parity bit.
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint qs0 = data_a[ib].qs[8 * is + 2 * l];
const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];
const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);
const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -0,0 +1,32 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = float(data_a[ib].d);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@@ -0,0 +1,34 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (1 scale and 32 quantized values)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const float d = float(data_a[ib].d);
// Scales are 6 bits
const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
| (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
const float dl = d * (int(scale) - 32);
const uint b_idx = 256 * ib + 32 * ib32;
const uint q_idx = 16 * ib32;
[[unroll]] for (uint l = 0; l < 16; ++l) {
data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@@ -0,0 +1,32 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@@ -0,0 +1,34 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint ip = tid / 32;
const uint il = tid - 32 * ip;
const uint is = 8 * ip + il / 16;
const uint y_idx = i * QUANT_K + 128 * ip + il;
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
}
}

View File

@@ -0,0 +1,42 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.nel / QUANT_K) {
return;
}
const uint r = gl_LocalInvocationID.x / 4;
const uint tid = r / 2;
const uint is0 = r % 2;
const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
const uint n = tid / 4;
const uint j = tid - 4*n;
const uint8_t m = uint8_t(1 << (4*n + j));
const uint is = 8*n + 2*j + is0;
const uint shift = 2*j;
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
(data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
const uint qs_idx = 32*n;
for (uint l = l0; l < l0 + 4; ++l) {
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
}
}
}

View File

@@ -0,0 +1,30 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = float(data_a[ib].d);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));
data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f));
}
}

View File

@@ -0,0 +1,32 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
}
}

View File

@@ -0,0 +1,68 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint il = tid / 8;
const uint ir = tid % 8;
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
uint scidx0 = (is < 4) ? is : (is + 4);
uint scidx1 = (is < 4) ? is : (is - 4);
uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint scidxshift1 = (is < 4) ? 0 : 2;
uint mbidx0 = is + 4;
uint mbidx1 = (is < 4) ? is + 4 : is;
uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint mbidxshift0 = (is < 4) ? 0 : 4;
uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint mbidxshift1 = (is < 4) ? 0 : 2;
uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * mbyte;
scidx0 = (is < 4) ? is + 1 : (is + 5);
scidx1 = (is < 4) ? is + 1 : (is - 3);
scidxmask1 = (is < 4) ? 0x30 : 0xC0;
scidxshift1 = (is < 4) ? 0 : 2;
mbidx0 = is + 5;
mbidx1 = (is < 4) ? is + 5 : is + 1;
mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbidxshift0 = (is < 4) ? 0 : 4;
mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * mbyte;
[[unroll]] for (uint l = 0; l < n; ++l) {
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1);
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2);
}
}
}

View File

@@ -0,0 +1,34 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
}
}

View File

@@ -0,0 +1,35 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint qh = data_a[ib].qh;
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
}
}

View File

@@ -0,0 +1,70 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint il = tid / 16;
const uint ir = tid % 16;
const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
const uint qh_idx = 2 * ir;
uint scidx0 = (is < 4) ? is : (is + 4);
uint scidx1 = (is < 4) ? is : (is - 4);
uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint scidxshift1 = (is < 4) ? 0 : 2;
uint mbidx0 = is + 4;
uint mbidx1 = (is < 4) ? is + 4 : is;
uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint mbidxshift0 = (is < 4) ? 0 : 4;
uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint mbidxshift1 = (is < 4) ? 0 : 2;
uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * mbyte;
scidx0 = (is < 4) ? is + 1 : (is + 5);
scidx1 = (is < 4) ? is + 1 : (is - 3);
scidxmask1 = (is < 4) ? 0x30 : 0xC0;
scidxshift1 = (is < 4) ? 0 : 2;
mbidx0 = is + 5;
mbidx1 = (is < 4) ? is + 5 : is + 1;
mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbidxshift0 = (is < 4) ? 0 : 4;
mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * mbyte;
const uint8_t hm1 = uint8_t(1 << (2 * il ));
const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
}
}

View File

@@ -0,0 +1,33 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint ip = tid / 32;
const uint il = tid - 32 * ip;
const uint is = 8 * ip + il / 16;
const uint y_idx = i * QUANT_K + 128 * ip + il;
const uint ql_idx = 64 * ip + il;
const uint8_t qh = data_a[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
}

View File

@@ -0,0 +1,31 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 16*il;
const float d = float(data_a[ib].d);
const uint q_idx = 16*il;
[[unroll]] for (uint l = 0; l < 16; l += 2) {
data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
}
}

View File

@@ -0,0 +1,34 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : enable
layout (push_constant) uniform parameter
{
uint ncols;
uint rows_per_channel;
uint n_past;
} p;
#include "types.glsl"
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint col = gl_GlobalInvocationID.y;
const uint row = gl_GlobalInvocationID.x;
if (col >= p.ncols) {
return;
}
const uint i = row*p.ncols + col;
if (col > p.n_past + row % p.rows_per_channel) {
data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
} else {
data_d[i] = D_TYPE(data_a[i]);
}
}

View File

@@ -0,0 +1,27 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}

View File

@@ -0,0 +1,21 @@
#version 450
#include "rte.glsl"
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(exp(float(data_a[i])));
}

View File

@@ -0,0 +1,383 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
vec4 Of[Br][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
}
float Lf[Br], Mf[Br];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
float slope[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = 1.0;
}
// ALiBi
if (p.max_bias > 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = 0.0;
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
}
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
}
}
}
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
}
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
}
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
}
eMf[r] = exp(Moldf[r] - Mf[r]);
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] += Pf[r][c] * Vf;
}
}
}
barrier();
}
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float rowmaxf, eMf;
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
}
barrier();
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
}
barrier();
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += tmpshv4[tid + s];
tmpshv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpshv4[d_tid];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}

View File

@@ -0,0 +1,202 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask_n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
#define SINK_ENABLE_BIT (1<<24)
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// Load the sink value, indexed by Q's dimension 2.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
return ACC_TYPE(data_s[h]);
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
{
N = p.N;
KV = p.KV;
i = gl_WorkGroupID.x;
split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
Tr = CEIL_DIV(N, Br);
start_j = split_k_index * p.split_kv / Bc;
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
iq3 = gl_WorkGroupID.z;
// broadcast factors
rk2 = p.neq2/p.nek2;
rk3 = p.neq3/p.nek3;
rv2 = p.neq2/p.nev2;
rv3 = p.neq3/p.nev3;
// k indices
ik3 = iq3 / rk3;
ik2 = iq2 / rk2;
// v indices
iv3 = iq3 / rv3;
iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
k_stride = p.nb11;
v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}

View File

@@ -0,0 +1,418 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
#define tile_row(r) (row_tid * rows_per_thread + (r))
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
if ((HSK % 16) != 0) {
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
if (i + tid < Br * qstride) {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
}
}
barrier();
}
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
// ALiBi
if (p.max_bias > 0.0f) {
if (tid < Br) {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
}
barrier();
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
uint coord = gl_SubgroupID * MatBc * sfshstride;
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
}
barrier();
}
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
barrier();
}
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}

View File

@@ -0,0 +1,320 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable
#include "types.glsl"
#include "dequant_funcs_cm2.glsl"
#include "flash_attn_base.glsl"
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y);
}
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
// Replace matrix elements >= numRows or numCols with 'replace'
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
if (row >= numRows || col >= numCols) {
return replace;
}
return elem;
}
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
return exp(elem);
}
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
return max(elem0, elem1);
}
#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c < HSV) {
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if defined(BLOCK_SIZE)
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
k_stride &= ~7;
v_stride &= ~7;
#endif
m_stride &= ~7;
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
// ALiBi
if (p.max_bias > 0.0f) {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
[[unroll]]
for (int k = 0; k < S.length(); ++k) {
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
}
// Clear padding elements to -inf, so they don't contribute to rowmax
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
(i + 1) * Br > N)) {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
coopMatPerElementNV(M, rowmax, Max, Mold);
coopMatPerElementNV(P, S - M, Exp);
coopMatPerElementNV(eM, Mold - M, Exp);
// Clear padding elements to 0, so they don't contribute to rowsum
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
(i + 1) * Br > N)) {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
}
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
// compute rowsum by multiplying by matrix of all ones.
coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared
// across the row
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
// resize M by using smear/reduce
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
// O, Ldiag, Mr all have the same type so all element locations match
[[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
ACC_TYPE sink = S[i];
ACC_TYPE ms = ACC_TYPE(1.0f);
ACC_TYPE vs = ACC_TYPE(1.0f);
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);
O[i] *= ms;
} else {
vs = exp(sink - Mr[i]);
}
Ldiag[i] = Ldiag[i]*ms + vs;
}
}
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
}
O = Ldiag*O;
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
}
}

View File

@@ -0,0 +1,120 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) readonly buffer B {float data_s[];};
layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
uint sinks;
} p;
shared float tmpsh[BLOCK_SIZE];
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m);
}
// reduce across the workgroup
tmpsh[tid] = m_max;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
m_max = max(m_max, tmpsh[tid + s]);
tmpsh[tid] = m_max;
}
barrier();
}
m_max = tmpsh[0];
barrier();
// Compute L based on m_max
float L = 0;
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float l = data_a[l_offset + (k + tid) * lm_stride];
float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l;
}
// reduce across the workgroup
tmpsh[tid] = L;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
L += tmpsh[tid + s];
tmpsh[tid] = L;
}
barrier();
}
L = tmpsh[0];
float sink;
if (p.sinks != 0) {
sink = data_s[n];
float ms = 1.0f;
float vs = 1.0f;
if (sink > m_max) {
ms = exp(m_max - sink);
} else {
vs = exp(sink - m_max);
}
L = L*ms + vs;
}
L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
if (p.sinks != 0) {
if (sink > m_max) {
float ms = 1.0f;
ms = exp(m_max - sink);
O *= ms;
}
}
O *= L;
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
data_d[iq3 * D * N + D * n + d] = O;
}
}

View File

@@ -0,0 +1,13 @@
#version 450
#include "glu_head.glsl"
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
float op(float a, float b) {
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
}
#include "glu_main.glsl"

View File

@@ -0,0 +1,27 @@
#version 450
#include "glu_head.glsl"
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
float op(float a, float b) {
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
return 0.5f * a * (1.0f + erf_approx) * b;
}
#include "glu_main.glsl"

View File

@@ -0,0 +1,11 @@
#version 450
#include "glu_head.glsl"
const float GELU_QUICK_COEF = -1.702f;
float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
#include "glu_main.glsl"

View File

@@ -0,0 +1,25 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float xi = float(data_a[i]);
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
}

View File

@@ -0,0 +1,39 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float a = float(data_a[i]);
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
}

View File

@@ -0,0 +1,23 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_QUICK_COEF = -1.702f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
}

View File

@@ -0,0 +1,51 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.glsl"
#include "utils.glsl"
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint misalign_offsets;
float param1; float param2; int param3;
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
layout(constant_id = 0) const bool norepeat = false;
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
if (norepeat) {
return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
} else {
return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
}
}
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
}

View File

@@ -0,0 +1,9 @@
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint KX;
uint KY;
float param1;
float param2;
} p;

View File

@@ -0,0 +1,76 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint misalign_offsets;
float param1; float param2;
uint ne0_012mp; uint ne0_012L;
uint ne0_01mp; uint ne0_01L;
uint ne0_0mp; uint ne0_0L;
uint ne1_012mp; uint ne1_012L;
uint ne1_01mp; uint ne1_01L;
uint ne1_0mp; uint ne1_0L;
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
uint src0_idx(uint idx) {
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
uint dst_idx(uint idx) {
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
}
uint src0_idx_quant(uint idx, uint qk) {
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00;
}
uint dst_idx_quant(uint idx, uint qk) {
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10;
}

View File

@@ -0,0 +1,42 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = gl_GlobalInvocationID.x;
if (i00 >= p.ne00) {
return;
}
uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#if defined(DATA_A_BF16)
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(v);
#else
data_d[d_offset + i00] = D_TYPE(v);
#endif
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
}

View File

@@ -0,0 +1,51 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
#include "generic_binary_head.glsl"
#include "dequant_funcs.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2;
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
if (i00 >= p.ne00) {
return;
}
uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint ib = a_offset + i00/QUANT_K; // block index
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
vec2 v = dequantize(ib, iqs, 0);
const vec2 dm = get_dm(ib, 0);
v = v * dm.x + dm.y;
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
}

View File

@@ -0,0 +1,19 @@
#extension GL_EXT_shader_16bit_storage : require
#include "rte.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter
{
uint N;
uint ne00;
uint ne20;
uint mode;
float alpha;
float limit;
} p;

View File

@@ -0,0 +1,29 @@
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.N) {
return;
}
const uint row = i / p.ne20;
const uint col = i - row * p.ne20;
if (p.mode == 0) {
// Default
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
} else if (p.mode == 1) {
// Swapped
const uint offset = p.ne00 / 2;
const uint idx = row * p.ne00 + col;
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
} else {
// Split
const uint idx = row * p.ne00 + col;
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
}
}

View File

@@ -0,0 +1,66 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared float tmp[BLOCK_SIZE];
void main() {
const uint group_size = p.KX;
const float eps = p.param1;
const uint tid = gl_LocalInvocationID.x;
const uint start = gl_WorkGroupID.x * group_size + tid;
const uint end = (gl_WorkGroupID.x + 1) * group_size;
tmp[tid] = 0.0f;
// Calculate mean
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
tmp[tid] += float(data_a[col]);
}
// tmp up partial tmps and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
const float mean = tmp[0] / group_size;
barrier();
tmp[tid] = 0.0f;
// Calculate variance
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
const float xi = float(data_a[col]) - mean;
data_d[col] = D_TYPE(xi);
tmp[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
const float variance = tmp[0] / group_size;
const float scale = inversesqrt(variance + eps);
[[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) {
data_d[col] *= D_TYPE(scale);
}
}

View File

@@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

View File

@@ -0,0 +1,22 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

View File

@@ -0,0 +1,103 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.glsl"
#include "types.glsl"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint batch_offset; uint offset_delta;
uint IC;
uint IW; uint IH;
uint OW; uint OH;
uint KW; uint KH;
uint pelements;
uint CHW;
int s0; int s1;
int p0; int p1;
int d0; int d1;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
const uint NUM_ITER = 512 / BLOCK_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
const uint base_linear_idx = gidx * NUM_ITER;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
A_TYPE values[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
}
}

View File

@@ -0,0 +1,125 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "rte.glsl"
#include "types.glsl"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint32_t i = gl_GlobalInvocationID.x;
uint32_t nb10 = p.nb10;
uint32_t nb11 = p.nb11;
uint32_t nb12 = p.nb12;
uint32_t nb13 = p.nb13;
uint32_t s0 = p.s0;
uint32_t s1 = p.s1;
uint32_t s2 = p.s2;
uint32_t p0 = p.p0;
uint32_t p1 = p.p1;
uint32_t p2 = p.p2;
uint32_t d0 = p.d0;
uint32_t d1 = p.d1;
uint32_t d2 = p.d2;
uint32_t IW = p.IW;
uint32_t IH = p.IH;
uint32_t ID = p.ID;
uint32_t IC = p.IC;
uint32_t KW = p.KW;
uint32_t OH = p.OH;
uint32_t KD_KH_KW = p.KD_KH_KW;
uint32_t KH_KW = p.KH_KW;
uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
uint32_t N_OD_OH = p.N_OD_OH;
uint32_t OD_OH = p.OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;
if (i >= IC_KD_KH_KW) {
return;
}
const uint32_t iic = i / KD_KH_KW;
const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const uint32_t ikw = i % KW;
const uint32_t iow = gl_GlobalInvocationID.y;
for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
const uint32_t in_ = iz / OD_OH;
const uint32_t iod = (iz - in_*OD_OH) / OH;
const uint32_t ioh = iz % OH;
const uint32_t iiw = iow * s0 + ikw * d0 - p0;
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
if (iih >= IH || iiw >= IW || iid >= ID) {
dst_addr.d = D_TYPE(0.0f);
} else {
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#else
if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#endif
}
}

Some files were not shown because too many files have changed in this diff Show More