Compare commits

..

3 Commits

Author SHA1 Message Date
ParthSareen
cea61b0434 fix model loading 2025-10-24 15:03:39 -07:00
ParthSareen
b883c895d5 remove unused stuff cmd 2025-10-16 16:16:39 -07:00
ParthSareen
7fc9560414 remove generate path from ollama run 2025-10-13 16:40:59 -07:00
211 changed files with 526 additions and 33561 deletions

View File

@@ -237,13 +237,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive_novulkan
target: archive
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive_novulkan
target: archive
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -299,14 +299,12 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -319,14 +317,6 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -344,7 +334,6 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.target }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest

View File

@@ -52,12 +52,6 @@ jobs:
container: rocm/dev-ubuntu-22.04:6.1.2
extra-packages: rocm-libs
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm'
- preset: Vulkan
container: ubuntu:22.04
extra-packages: >
mesa-vulkan-drivers vulkan-tools
libvulkan1 libvulkan-dev
vulkan-sdk cmake ccache g++ make
runs-on: linux
container: ${{ matrix.container }}
steps:
@@ -65,19 +59,7 @@ jobs:
- run: |
[ -n "${{ matrix.container }}" ] || sudo=sudo
$sudo apt-get update
# Add LunarG Vulkan SDK apt repo for Ubuntu 22.04
if [ "${{ matrix.preset }}" = "Vulkan" ]; then
$sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg
# Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY
echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null
$sudo apt-get update
fi
$sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }}
# Export VULKAN_SDK if provided by LunarG package (defensive)
if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then
echo "VULKAN_SDK=/usr" >> $GITHUB_ENV
fi
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/cache@v4
@@ -110,21 +92,18 @@ jobs:
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
- preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
runs-on: windows
steps:
- run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan'
- if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm'
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: matrix.preset == 'CUDA'
name: Install CUDA ${{ matrix.cuda-version }}
@@ -154,18 +133,6 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4
with:

View File

@@ -139,15 +139,3 @@ if(CMAKE_HIP_COMPILER)
endforeach()
endif()
endif()
find_package(Vulkan)
if(Vulkan_FOUND)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan)
install(TARGETS ggml-vulkan
RUNTIME_DEPENDENCIES
PRE_INCLUDE_REGEXES vulkan
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan
)
endif()

View File

@@ -30,7 +30,7 @@
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;60;61;70;75;80;86;89;90;90a;120",
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
@@ -38,7 +38,7 @@
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2"
}
},
@@ -70,10 +70,6 @@
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
},
{
"name": "Vulkan",
"inherits": [ "Default" ]
}
],
"buildPresets": [
@@ -126,11 +122,6 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"configurePreset": "ROCm 6"
},
{
"name": "Vulkan",
"targets": [ "ggml-vulkan" ],
"configurePreset": "Vulkan"
}
]
}

View File

@@ -7,7 +7,6 @@ ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2
ARG VULKANVERSION=1.4.321.1
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
@@ -18,16 +17,6 @@ RUN yum install -y yum-utils \
&& dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
ARG VULKANVERSION
RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \
&& dnf -y install ninja-build \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \
&& /${VULKANVERSION}/vulkansdk -j 8 shaderc
RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \
&& cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib
ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache
@@ -117,13 +106,6 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \
&& cmake --build --parallel --preset 'Vulkan' \
&& cmake --install build --component Vulkan --strip --parallel 8
FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
@@ -141,8 +123,7 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
COPY --from=vulkan dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
@@ -155,40 +136,14 @@ FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
FROM ${FLAVOR} AS archive
ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
FROM ubuntu:24.04
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM ubuntu:24.04 AS default
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive /lib/ollama /usr/lib/ollama

View File

@@ -544,7 +544,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
- [achatbot-go](https://github.com/ai-bot-pro/achatbot-go) a multimodal(text/audio/image) chatbot.
### Mobile

View File

@@ -204,7 +204,7 @@ type ToolCall struct {
}
type ToolCallFunction struct {
Index int `json:"index"`
Index int `json:"index,omitempty"`
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
@@ -266,9 +266,9 @@ func (pt PropertyType) String() string {
type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"`
Type PropertyType `json:"type"`
Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"`
Description string `json:"description"`
Enum []any `json:"enum,omitempty"`
}
@@ -332,7 +332,7 @@ func (t *ToolFunctionParameters) String() string {
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
Description string `json:"description"`
Parameters ToolFunctionParameters `json:"parameters"`
}

View File

@@ -298,30 +298,6 @@ func TestToolFunction_UnmarshalJSON(t *testing.T) {
}
}
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{
Name: "echo",
Arguments: ToolCallFunctionArguments{"message": "hi"},
}
data, err := json.Marshal(fn)
require.NoError(t, err)
raw := map[string]any{}
require.NoError(t, json.Unmarshal(data, &raw))
require.Contains(t, raw, "index")
assert.Equal(t, float64(0), raw["index"])
fn.Index = 3
data, err = json.Marshal(fn)
require.NoError(t, err)
raw = map[string]any{}
require.NoError(t, json.Unmarshal(data, &raw))
require.Contains(t, raw, "index")
assert.Equal(t, float64(3), raw["index"])
}
func TestPropertyType_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string

View File

@@ -473,7 +473,34 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
local := opts.Copy()
if local.MultiModal {
prompt, images, err := extractFileData(local.Prompt)
if err != nil {
return err
}
local.Prompt = prompt
local.Images = images
}
messages := slices.Clone(local.Messages)
if local.System != "" {
hasSystem := len(messages) > 0 && messages[0].Role == "system"
if !hasSystem {
messages = append([]api.Message{{Role: "system", Content: local.System}}, messages...)
}
}
if local.Prompt != "" || len(local.Images) > 0 {
messages = append(messages, api.Message{Role: "user", Content: local.Prompt, Images: local.Images})
}
local.Messages = messages
_, err = chat(cmd, local)
return err
}
func SigninHandler(cmd *cobra.Command, args []string) error {
@@ -1099,8 +1126,6 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return client.Pull(cmd.Context(), &request, fn)
}
type generateContextKey string
type runOptions struct {
Model string
ParentModel string
@@ -1363,137 +1388,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
var latest api.GenerateResponse
generateContext, ok := cmd.Context().Value(generateContextKey("context")).([]int)
if !ok {
generateContext = []int{}
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var thinkingContent strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
fn := func(response api.GenerateResponse) error {
latest = response
content := response.Response
if response.Response != "" || !opts.HideThinking {
p.StopAndClear()
}
if response.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(plainText))
thinkTagOpened = true
thinkTagClosed = false
}
thinkingContent.WriteString(response.Thinking)
displayResponse(response.Thinking, opts.WordWrap, state)
}
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) {
if !strings.HasSuffix(thinkingContent.String(), "\n") {
fmt.Println()
}
fmt.Print(thinkingOutputClosingText(plainText))
thinkTagOpened = false
thinkTagClosed = true
state = &displayResponseState{}
}
displayResponse(content, opts.WordWrap, state)
if response.ToolCalls != nil {
toolCalls := response.ToolCalls
if len(toolCalls) > 0 {
fmt.Print(renderToolCalls(toolCalls, plainText))
}
}
return nil
}
if opts.MultiModal {
opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
if err != nil {
return err
}
}
if opts.Format == "json" {
opts.Format = `"` + opts.Format + `"`
}
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Images: opts.Images,
Format: json.RawMessage(opts.Format),
System: opts.System,
Options: opts.Options,
KeepAlive: opts.KeepAlive,
Think: opts.Think,
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
}
return err
}
if opts.Prompt != "" {
fmt.Println()
fmt.Println()
}
if !latest.Done {
return nil
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return err
}
if verbose {
latest.Summary()
}
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)
return nil
}
func RunServer(_ *cobra.Command, _ []string) error {
if err := initializeKeypair(); err != nil {
return err

View File

@@ -18,7 +18,6 @@ import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
)
@@ -340,8 +339,13 @@ func TestConvertAdapter(t *testing.T) {
}
actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
if diff := cmp.Diff(c.Expected, actual); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k)
} else if v != c.Expected[k] {
t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
}
}
})
}

View File

@@ -70,7 +70,6 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
if dev.Library == "ROCm" && rocmDir != "" {
info.DependencyPath = append(info.DependencyPath, rocmDir)
}
// TODO any special processing of Vulkan devices?
resp = append(resp, info)
}
if len(resp) == 0 {
@@ -98,16 +97,7 @@ func (l GpuInfoList) GetVisibleDevicesEnv() []string {
if len(l) == 0 {
return nil
}
res := []string{}
envVar := rocmGetVisibleDevicesEnv(l)
if envVar != "" {
res = append(res, envVar)
}
envVar = vkGetVisibleDevicesEnv(l)
if envVar != "" {
res = append(res, envVar)
}
return res
return []string{rocmGetVisibleDevicesEnv(l)}
}
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
@@ -137,25 +127,6 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
return envVar + strings.Join(ids, ",")
}
func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "Vulkan" {
continue
}
if info.filterID != "" {
ids = append(ids, info.filterID)
} else {
ids = append(ids, info.ID)
}
}
if len(ids) == 0 {
return ""
}
envVar := "GGML_VK_VISIBLE_DEVICES="
return envVar + strings.Join(ids, ",")
}
// GetSystemInfo returns the last cached state of the GPUs on the system
func GetSystemInfo() SystemInfo {
deviceMu.Lock()

View File

@@ -86,9 +86,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
// are enumerated, but not actually supported.
// We run this in serial to avoid potentially initializing a GPU multiple
// times concurrently leading to memory contention
// TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs
for dir := range libDirs {
bootstrapTimeout := 30 * time.Second
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
@@ -103,16 +101,11 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
} else {
dirs = []string{LibOllamaPath, dir}
}
// ROCm can take a long time on some systems, so give it more time before giving up
if dir != "" && strings.Contains(filepath.Base(dir), "rocm") {
bootstrapTimeout = 60 * time.Second
}
// Typically bootstrapping takes < 1s, but on some systems, with devices
// in low power/idle mode, initialization can take multiple seconds. We
// set a long timeout just for bootstrap discovery to reduce the chance
// of giving up too quickly
ctx1stPass, cancel := context.WithTimeout(ctx, bootstrapTimeout)
ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// For this pass, we retain duplicates in case any are incompatible with some libraries
@@ -138,25 +131,19 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
go func(i int) {
defer wg.Done()
var envVar string
id := devices[i].ID
if devices[i].Library == "ROCm" {
if runtime.GOOS != "linux" {
envVar = "HIP_VISIBLE_DEVICES"
} else {
envVar = "ROCR_VISIBLE_DEVICES"
}
} else if devices[i].Library == "CUDA" {
envVar = "CUDA_VISIBLE_DEVICES"
} else if devices[i].Library == "Vulkan" {
id = devices[i].FilteredID
envVar = "GGML_VK_VISIBLE_DEVICES"
} else {
slog.Error("Unknown Library:" + devices[i].Library)
envVar = "CUDA_VISIBLE_DEVICES"
}
extraEnvs := []string{
"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
envVar + "=" + id, // Filter to just this one GPU
"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
envVar + "=" + devices[i].ID, // Filter to just this one GPU
}
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
needsDelete[i] = true
@@ -176,8 +163,6 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
wg.Wait()
logutil.Trace("supported GPU library combinations", "supported", supported)
filterOutVulkanThatAreSupportedByOtherGPU(needsDelete)
// Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible
filterOverlapByLibrary(supported, needsDelete)
@@ -199,7 +184,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
}
// Now filter out any overlap with different libraries (favor CUDA/HIP over others)
// Now filter out any overlap with different libraries (favor CUDA/ROCm over others)
for i := 0; i < len(devices); i++ {
for j := i + 1; j < len(devices); j++ {
// For this pass, we only drop exact duplicates
@@ -355,38 +340,10 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
}
return devices
}
// Apply any iGPU workarounds
iGPUWorkarounds(devices)
func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) {
// Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion
for i := range devices {
if devices[i].Library != "Vulkan" || needsDelete[i] {
continue
}
if devices[i].PCIID == "" {
continue
}
for j := range devices {
if i == j {
continue
}
if devices[j].PCIID == "" {
continue
}
if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] {
needsDelete[i] = true
slog.Debug("dropping Vulkan duplicate by PCI ID",
"vulkan_id", devices[i].ID,
"vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1],
"pci_id", devices[i].PCIID,
"kept_library", devices[j].Library,
"kept_id", devices[j].ID,
)
break
}
}
}
return devices
}
func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) {
@@ -494,7 +451,6 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
}
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
pathNeeded := true
@@ -552,7 +508,6 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
}
}
logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices)
return devices
}
@@ -604,3 +559,32 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceIn
}
}
}
func iGPUWorkarounds(devices []ml.DeviceInfo) {
// short circuit if we have no iGPUs
anyiGPU := false
for i := range devices {
if devices[i].Integrated {
anyiGPU = true
break
}
}
if !anyiGPU {
return
}
memInfo, err := GetCPUMem()
if err != nil {
slog.Debug("failed to fetch system memory information for iGPU", "error", err)
return
}
for i := range devices {
if !devices[i].Integrated {
continue
}
// NVIDIA iGPUs return useless free VRAM data which ignores system buff/cache
if devices[i].Library == "CUDA" {
devices[i].FreeMemory = memInfo.FreeMemory
}
}
}

View File

@@ -37,7 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
UnreliableFreeMemory bool
// GPU information
filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
ComputeMajor int `json:"compute_major"` // Compute Capability or gfx
ComputeMinor int `json:"compute_minor"`
@@ -175,8 +175,7 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
gpu.Library == "ROCm"
if !supportsFA {
return false

View File

@@ -9,20 +9,15 @@ Check your compute compatibility to see if your card is supported:
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professioal | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 11.0 | Jetson | `T4000` `T5000` (Requires driver 580 or newer) |
| 10.3 | NVIDIA Professioal | `B300` `GB300` (Requires driver 580 or newer) |
| 10.0 | NVIDIA Professioal | `B200` `GB200` (Requires driver 580 or newer) |
| 9.0 | NVIDIA | `H200` `H100` `GH200` |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.7 | Jetson | `Orin Nano` `Orin NX` `AGX Orin` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
| 8.0 | NVIDIA | `A100` `A30` |
| 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
| | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
| 7.2 | Jetson | `Xavier NX` `AGX Xavier` (Jetpack 5) |
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |

View File

@@ -24,9 +24,6 @@ func Host() *url.URL {
switch {
case !ok:
scheme, hostport = "http", s
if s == "ollama.com" {
scheme, hostport = "https", "ollama.com:443"
}
case scheme == "http":
defaultPort = "80"
case scheme == "https":
@@ -220,7 +217,6 @@ var (
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES")
VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES")
GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL")
HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION")
)
@@ -311,7 +307,6 @@ func AsMap() map[string]EnvVar {
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"}
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"}
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}

View File

@@ -37,7 +37,6 @@ func TestHost(t *testing.T) {
"https": {"https://1.2.3.4", "https://1.2.3.4:443"},
"https port": {"https://1.2.3.4:4321", "https://1.2.3.4:4321"},
"proxy path": {"https://example.com/ollama", "https://example.com:443/ollama"},
"ollama.com": {"ollama.com", "https://ollama.com:443"},
}
for name, tt := range cases {

View File

@@ -893,7 +893,6 @@ func (f GGML) SupportsFlashAttention() bool {
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"gemma3",
"gptoss", "gpt-oss",
"qwen3",
"qwen3moe",

View File

@@ -509,10 +509,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
}
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
}
alignment := kv.Uint("general.alignment", 32)
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
@@ -531,7 +528,7 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
}
for _, key := range slices.Sorted(maps.Keys(kv)) {
if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err
}
}
@@ -546,8 +543,6 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
},
)
alignment := kv.Uint("general.alignment", 32)
var s uint64
for i := range ts {
ts[i].Offset = s
@@ -579,14 +574,7 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return g.Wait()
}
func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
if !strings.HasPrefix(k, arch+".") &&
!strings.HasPrefix(k, "general.") &&
!strings.HasPrefix(k, "adapter.") &&
!strings.HasPrefix(k, "tokenizer.") {
k = arch + "." + k
}
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
slog.Debug(k, "type", fmt.Sprintf("%T", v))
if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
return err

View File

@@ -39,12 +39,7 @@ func TestWriteGGUF(t *testing.T) {
defer w.Close()
if err := WriteGGUF(w, KV{
"general.architecture": "test",
"general.alignment": uint32(16),
"test.key": "value",
"attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
"general.alignment": uint32(16),
}, ts); err != nil {
t.Fatal(err)
}
@@ -61,19 +56,14 @@ func TestWriteGGUF(t *testing.T) {
}
if diff := cmp.Diff(KV{
"general.architecture": "test",
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
"test.key": "value",
"test.attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 800,
Offset: 592,
items: []*Tensor{
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},

View File

@@ -229,7 +229,7 @@ const (
TensorTypeMXFP4
)
// ParseTensorType parses the provided GGUF tensor type
// ParseFileType parses the provided GGUF file type
// Only Ollama supported types are considered valid
func ParseTensorType(s string) (TensorType, error) {
switch s {

View File

@@ -109,8 +109,6 @@ func TestMultiModelStress(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
initialTimeout := 120 * time.Second
streamTimeout := 20 * time.Second
// Make sure all the models are pulled before we get started
for _, model := range chosenModels {
@@ -149,8 +147,6 @@ chooseModels:
for _, m := range models.Models {
if m.SizeVRAM == 0 {
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
initialTimeout = 240 * time.Second
streamTimeout = 30 * time.Second
break chooseModels
}
}
@@ -176,7 +172,10 @@ chooseModels:
k := r.Int() % len(reqs)
reqs[k].Model = chosenModels[i]
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
DoChat(ctx, t, client, reqs[k], resps[k], initialTimeout, streamTimeout)
DoChat(ctx, t, client, reqs[k], resps[k],
120*time.Second, // Be extra patient for the model to load initially
10*time.Second, // Once results start streaming, fail if they stall
)
}
}(i)
}

View File

@@ -78,7 +78,7 @@ func TestContextExhaustion(t *testing.T) {
// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestParallelGenerateWithHistory(t *testing.T) {
modelName := "gpt-oss:20b"
modelOverride := "gpt-oss:20b"
req, resp := GenerateRequests()
numParallel := 2
iterLimit := 2
@@ -88,23 +88,15 @@ func TestParallelGenerateWithHistory(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
initialTimeout := 120 * time.Second
streamTimeout := 20 * time.Second
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", modelName)
slog.Info("loading", "model", modelOverride)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelName, err)
}
gpuPercent := getGPUPercent(ctx, t, client, modelName)
if gpuPercent < 80 {
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
initialTimeout = 240 * time.Second
streamTimeout = 30 * time.Second
t.Fatalf("failed to load model %s: %s", modelOverride, err)
}
var wg sync.WaitGroup
@@ -113,7 +105,7 @@ func TestParallelGenerateWithHistory(t *testing.T) {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelName
req[k].Model = modelOverride
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
@@ -122,7 +114,7 @@ func TestParallelGenerateWithHistory(t *testing.T) {
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
c := DoGenerate(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
req[k].Context = c
req[k].Prompt = "tell me more!"
}
@@ -173,7 +165,7 @@ func TestGenerateWithHistory(t *testing.T) {
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestParallelChatWithHistory(t *testing.T) {
modelName := "gpt-oss:20b"
modelOverride := "gpt-oss:20b"
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2
@@ -183,23 +175,15 @@ func TestParallelChatWithHistory(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
initialTimeout := 120 * time.Second
streamTimeout := 20 * time.Second
// Get the server running (if applicable) warm the model up with a single initial empty request
slog.Info("loading", "model", modelName)
slog.Info("loading", "model", modelOverride)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelName, err)
}
gpuPercent := getGPUPercent(ctx, t, client, modelName)
if gpuPercent < 80 {
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
initialTimeout = 240 * time.Second
streamTimeout = 30 * time.Second
t.Fatalf("failed to load model %s: %s", modelOverride, err)
}
var wg sync.WaitGroup
@@ -208,7 +192,7 @@ func TestParallelChatWithHistory(t *testing.T) {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelName
req[k].Model = modelOverride
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
@@ -217,7 +201,7 @@ func TestParallelChatWithHistory(t *testing.T) {
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
assistant := DoChat(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}

View File

@@ -65,23 +65,6 @@ func TestModelsChat(t *testing.T) {
}
}
}
initialTimeout := 120 * time.Second
streamTimeout := 30 * time.Second
slog.Info("loading", "model", model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", model, err)
}
gpuPercent := getGPUPercent(ctx, t, client, model)
if gpuPercent < 80 {
slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
initialTimeout = 240 * time.Second
streamTimeout = 40 * time.Second
}
// TODO - fiddle with context size
req := api.ChatRequest{
Model: model,
@@ -97,7 +80,7 @@ func TestModelsChat(t *testing.T) {
"seed": 123,
},
}
DoChat(ctx, t, client, req, blueSkyExpected, initialTimeout, streamTimeout)
DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
})

File diff suppressed because one or more lines are too long

View File

@@ -743,13 +743,6 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
gpuPercent := getGPUPercent(ctx, t, client, model)
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
}
func getGPUPercent(ctx context.Context, t *testing.T, client *api.Client, model string) int {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
@@ -779,10 +772,12 @@ func getGPUPercent(ctx context.Context, t *testing.T, client *api.Client, model
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
gpuPercent = int(100 - cpuPercent)
}
return gpuPercent
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
return
}
t.Fatalf("model %s not loaded - actually loaded: %v", model, loaded)
return 0
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {

View File

@@ -267,12 +267,10 @@ static struct llama_model * llama_model_load_from_file_impl(
for (auto * dev : model->devices) {
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
size_t memory_free, memory_total;
ggml_backend_dev_memory(dev, &memory_free, &memory_total);
LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__,
ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
props.device_id ? props.device_id : "unknown id",
memory_free/1024/1024);
props.memory_free/1024/1024);
}
const int status = llama_model_load(path_model, splits, *model, params);

View File

@@ -69,9 +69,7 @@ func EnumerateGPUs() []ml.DeviceID {
for i := range C.ggml_backend_dev_count() {
device := C.ggml_backend_dev_get(i)
switch C.ggml_backend_dev_type(device) {
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(device, &props)
ids = append(ids, ml.DeviceID{

View File

@@ -12,11 +12,10 @@ unused then it can be reset to free these data structures.
ggml/src/ggml-backend.cpp | 8 ++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 16 +++++++++++++++-
ggml/src/ggml-cuda/vendors/hip.h | 1 +
src/llama.cpp | 4 +++-
6 files changed, 32 insertions(+), 2 deletions(-)
5 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 1ff53ed03..ba181d09d 100644
index 1ff53ed0..ba181d09 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -178,6 +178,7 @@ extern "C" {
@@ -28,7 +27,7 @@ index 1ff53ed03..ba181d09d 100644
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 3c3f22fc0..43c91d9f2 100644
index 3c3f22fc..43c91d9f 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -195,6 +195,10 @@ extern "C" {
@@ -43,7 +42,7 @@ index 3c3f22fc0..43c91d9f2 100644
struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 6ef5eeafa..0b757af59 100644
index 6ef5eeaf..0b757af5 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -526,6 +526,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
@@ -62,7 +61,7 @@ index 6ef5eeafa..0b757af59 100644
GGML_ASSERT(device);
return device->iface.get_buffer_type(device);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 811462c79..87c6c34a4 100644
index 811462c7..87c6c34a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -107,6 +107,11 @@ int ggml_cuda_get_device() {
@@ -110,7 +109,7 @@ index 811462c79..87c6c34a4 100644
// backend reg
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 890c10364..1f06be80e 100644
index 890c1036..1f06be80 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -45,6 +45,7 @@
@@ -121,21 +120,3 @@ index 890c10364..1f06be80e 100644
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
diff --git a/src/llama.cpp b/src/llama.cpp
index fe5a7a835..d821a96a0 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -267,10 +267,12 @@ static struct llama_model * llama_model_load_from_file_impl(
for (auto * dev : model->devices) {
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
+ size_t memory_free, memory_total;
+ ggml_backend_dev_memory(dev, &memory_free, &memory_total);
LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__,
ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
props.device_id ? props.device_id : "unknown id",
- props.memory_free/1024/1024);
+ memory_free/1024/1024);
}
const int status = llama_model_load(path_model, splits, *model, params);

View File

@@ -6,23 +6,23 @@ Subject: [PATCH] GPU discovery enhancements
Expose more information about the devices through backend props, and leverage
management libraries for more accurate VRAM usage reporting if available.
---
ggml/include/ggml-backend.h | 11 +
ggml/include/ggml-backend.h | 9 +
ggml/src/CMakeLists.txt | 2 +
ggml/src/ggml-cuda/ggml-cuda.cu | 74 +++++
ggml/src/ggml-cuda/ggml-cuda.cu | 72 +++++
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 ++++++++++++++
8 files changed, 758 insertions(+)
8 files changed, 754 insertions(+)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index ba181d09d..094fc3c82 100644
index ba181d09..09ff75f9 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -169,6 +169,17 @@ extern "C" {
@@ -169,6 +169,15 @@ extern "C" {
const char * device_id;
// device capabilities
struct ggml_backend_dev_caps caps;
@@ -35,13 +35,11 @@ index ba181d09d..094fc3c82 100644
+ int pci_device_id;
+ int pci_domain_id;
+ const char *library;
+ // number with which the devices are accessed (Vulkan)
+ const char *numeric_id;
};
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 0609c6503..aefe43bdd 100644
index 0609c650..aefe43bd 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -209,6 +209,8 @@ add_library(ggml-base
@@ -54,7 +52,7 @@ index 0609c6503..aefe43bdd 100644
target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 87c6c34a4..816597d2f 100644
index 87c6c34a..6a278b5e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -161,23 +159,21 @@ index 87c6c34a4..816597d2f 100644
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
bool events = false;
@@ -4087,6 +4149,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
@@ -4087,6 +4149,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
+ int driverVersion = 0;
+ CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@@ -4102,6 +4165,17 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
@@ -4102,6 +4166,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
dev_ctx->pci_bus_id = pci_bus_id;
+ dev_ctx->major = prop.major;
+ dev_ctx->minor = prop.minor;
+ if (driverVersion == 0) {
+ CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
+ }
+ dev_ctx->driver_major = driverVersion / 1000;
+ dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10;
+ dev_ctx->integrated = prop.integrated;
@@ -188,7 +184,7 @@ index 87c6c34a4..816597d2f 100644
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 1f06be80e..2f9ef2dc0 100644
index 1f06be80..2f9ef2dc 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -5,6 +5,8 @@
@@ -209,7 +205,7 @@ index 1f06be80e..2f9ef2dc0 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index d0fb3bcca..80597b6ea 100644
index d0fb3bcc..80597b6e 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -638,6 +638,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
@@ -228,7 +224,7 @@ index d0fb3bcca..80597b6ea 100644
}
#endif
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index f2ff9f322..f356e4a0a 100644
index f2ff9f32..f356e4a0 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
@@ -249,7 +245,7 @@ index f2ff9f322..f356e4a0a 100644
/* .host_buffer = */ false,
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..8ef19b8cf
index 00000000..8ef19b8c
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,449 @@
@@ -705,7 +701,7 @@ index 000000000..8ef19b8cf
\ No newline at end of file
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
new file mode 100644
index 000000000..c9073cef0
index 00000000..c9073cef
--- /dev/null
+++ b/ggml/src/mem_nvml.cpp
@@ -0,0 +1,209 @@

View File

@@ -1,95 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Xiaodong Ye <xiaodong.ye@mthreads.com>
Date: Mon, 18 Aug 2025 12:48:07 +0800
Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5)
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++
1 file changed, 37 insertions(+)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 061cd078..adea7783 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_
snprintf(description, description_size, "%s", props.deviceName.data());
}
+static std::string ggml_vk_get_device_id(int device) {
+ ggml_vk_instance_init();
+
+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
+
+ vk::PhysicalDeviceProperties2 props;
+ vk::PhysicalDeviceIDProperties deviceIDProps;
+ props.pNext = &deviceIDProps;
+ devices[device].getProperties2(&props);
+
+ const auto& uuid = deviceIDProps.deviceUUID;
+ char id[64];
+ snprintf(id, sizeof(id),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ uuid[0], uuid[1], uuid[2], uuid[3],
+ uuid[4], uuid[5],
+ uuid[6], uuid[7],
+ uuid[8], uuid[9],
+ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15]
+ );
+ return std::string(id);
+}
+
// backend interface
#define UNUSED GGML_UNUSED
@@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
ggml_vk_get_device_description(dev_idx, description, description_size);
}
+std::string ggml_backend_vk_get_device_id(int device) {
+ GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+ int dev_idx = vk_instance.device_indices[device];
+ return ggml_vk_get_device_id(dev_idx);
+}
+
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context {
std::string description;
bool is_integrated_gpu;
std::string pci_bus_id;
+ std::string id;
};
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de
return ctx->description.c_str();
}
+static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+ return ctx->id.c_str();
+}
+
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
ggml_backend_vk_get_device_memory(ctx->device, free, total);
@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev);
+ props->id = ggml_backend_vk_device_get_id(dev);
props->type = ggml_backend_vk_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+ ctx->id = ggml_backend_vk_get_device_id(i);
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,
--
2.51.0

View File

@@ -1,254 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri Sep 5 08:25:03 2025 -0700
Subject: [PATCH] Vulkan PCI and Memory
---
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++-----
1 file changed, 145 insertions(+), 31 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index adea7783..fb7204ce 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) {
return ggml_vk_get_device_id(dev_idx);
}
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
- GGML_ASSERT(device < (int) vk_instance.device_indices.size());
- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+ size_t device;
+ std::string name;
+ std::string description;
+ bool is_integrated_gpu;
+ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function)
+ std::string pci_id;
+ std::string id;
+ std::string uuid;
+ int major;
+ int minor;
+ int driver_major;
+ int driver_minor;
+ int pci_bus_id;
+ int pci_device_id;
+ int pci_domain_id;
+};
+
+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) {
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size());
+ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size());
+
+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]];
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
- vk::PhysicalDeviceMemoryProperties2 memprops = {};
- bool membudget_supported = vk_instance.device_supports_membudget[device];
+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+ vk::PhysicalDeviceProperties2 props2;
+ vkdev.getProperties2(&props2);
- if (membudget_supported) {
- memprops.pNext = &budgetprops;
+ if (!ctx->is_integrated_gpu)
+ {
+ // Use vendor specific management libraries for best VRAM reporting if available
+ switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+ ggml_hip_mgmt_release();
+ return;
+ }
+ ggml_hip_mgmt_release();
+ }
+ break;
+ case VK_VENDOR_ID_NVIDIA:
+ if (ggml_nvml_init() == 0) {
+ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total);
+ ggml_nvml_release();
+ return;
+ }
+ ggml_nvml_release();
+ }
+ break;
+ }
}
- vkdev.getMemoryProperties2(&memprops);
+ // else fallback to memory budget if supported
- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
+ *total = 0;
+ *free = 0;
+ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props;
+ vk::PhysicalDeviceMemoryProperties2 memprops2;
+ memprops2.pNext = &mem_budget_props;
+ vkdev.getMemoryProperties2(&memprops2);
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ } else if (ctx->is_integrated_gpu) {
+ // Include shared memory on iGPUs
+ *total += memprops2.memoryProperties.memoryHeaps[i].size;
+ }
+ }
+ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) {
+ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+ *free += mem_budget_props.heapBudget[i];
+ } else if (ctx->is_integrated_gpu) {
+ *free += mem_budget_props.heapBudget[i];
+ }
+ }
+ if (*total > 0 && *free > 0) {
+ return;
+ } else if (*total > 0) {
+ *free = *total;
+ return;
+ }
+ // else just report the physical memory
+ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) {
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
*total = heap.size;
-
- if (membudget_supported && i < budgetprops.heapUsage.size()) {
- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
- } else {
- *free = heap.size;
- }
+ *free = heap.size;
break;
}
}
@@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
return std::string(pci_bus_id);
}
-//////////////////////////
-
-struct ggml_backend_vk_device_context {
- size_t device;
- std::string name;
- std::string description;
- bool is_integrated_gpu;
- std::string pci_bus_id;
- std::string id;
-};
+static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) {
+ if (id.empty()) return false;
+ unsigned int d = 0, b = 0, dev = 0, func = 0;
+ // Expected format: dddd:bb:dd.f (all hex)
+ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func);
+ if (n < 4) return false;
+ if (domain) *domain = (int) d;
+ if (bus) *bus = (int) b;
+ if (device) *device = (int) dev;
+ return true;
+}
static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
@@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) {
static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
- ggml_backend_vk_get_device_memory(ctx->device, free, total);
+ ggml_backend_vk_get_device_memory(ctx, free, total);
}
static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->description = ggml_backend_vk_device_get_description(dev);
props->id = ggml_backend_vk_device_get_id(dev);
props->type = ggml_backend_vk_device_get_type(dev);
- props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
+ props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str();
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = {
/* .async = */ false,
@@ -12564,6 +12633,17 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
/* .buffer_from_host_ptr = */ false,
/* .events = */ false,
};
+
+ props->compute_major = ctx->major;
+ props->compute_minor = ctx->minor;
+ props->driver_major = ctx->driver_major;
+ props->driver_minor = ctx->driver_minor;
+ props->integrated = ctx->is_integrated_gpu;
+ props->pci_bus_id = ctx->pci_bus_id;
+ props->pci_device_id = ctx->pci_device_id;
+ props->pci_domain_id = ctx->pci_domain_id;
+ props->library = GGML_VK_NAME;
+ props->numeric_id = ctx->id.empty() ? nullptr : ctx->id.c_str();
}
static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
@@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
if (!initialized) {
+ std::vector<vk::PhysicalDevice> vk_devices = vk_instance.instance.enumeratePhysicalDevices();
+
for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
char desc[256];
@@ -13000,13 +13081,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
ctx->name = GGML_VK_NAME + std::to_string(i);
ctx->description = desc;
ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
- ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
+ ctx->pci_id = ggml_backend_vk_get_device_pci_id(i);
ctx->id = ggml_backend_vk_get_device_id(i);
devices.push_back(new ggml_backend_device {
/* .iface = */ ggml_backend_vk_device_i,
/* .reg = */ reg,
/* .context = */ ctx,
});
+
+ // Gather additional information about the device
+ int dev_idx = vk_instance.device_indices[i];
+ vk::PhysicalDeviceProperties props1;
+ vk_devices[dev_idx].getProperties(&props1);
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceIDProperties device_id_props;
+ vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props;
+ vk::PhysicalDeviceDriverProperties driver_props;
+ props2.pNext = &device_id_props;
+ device_id_props.pNext = &pci_bus_props;
+ pci_bus_props.pNext = &driver_props;
+ vk_devices[dev_idx].getProperties2(&props2);
+ std::ostringstream oss;
+ oss << std::hex << std::setfill('0');
+ oss << "GPU-";
+ int byteIdx = 0;
+ for (int i = 0; i < 16; ++i, ++byteIdx) {
+ oss << std::setw(2) << static_cast<int>(device_id_props.deviceUUID[i]);
+ if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) {
+ oss << '-';
+ }
+ }
+ ctx->uuid = oss.str();
+ ctx->pci_bus_id = pci_bus_props.pciBus;
+ ctx->pci_device_id = pci_bus_props.pciDevice;
+ ctx->pci_domain_id = pci_bus_props.pciDomain;
+ ctx->id = std::to_string(i);
+ ctx->major = 0;
+ ctx->minor = 0;
+ // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string
+ ctx->driver_major = 0;
+ ctx->driver_minor = 0;
}
initialized = true;
}
--
2.51.0

View File

@@ -1,137 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Santosh Bhavani <santosh.bhavani@live.com>
Date: Wed, 15 Oct 2025 09:29:51 -0700
Subject: [PATCH] NVML fallback for unified memory GPUs
---
ggml/src/mem_nvml.cpp | 71 +++++++++++++++++++++++++++++++++++++++++--
1 file changed, 68 insertions(+), 3 deletions(-)
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
index c9073cef..f473a2a2 100644
--- a/ggml/src/mem_nvml.cpp
+++ b/ggml/src/mem_nvml.cpp
@@ -13,6 +13,7 @@
#include <filesystem>
#include <mutex>
#include <array>
+#include <cstring>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
@@ -23,6 +24,8 @@
#else
# include <dlfcn.h>
# include <unistd.h>
+# include <fstream>
+# include <string>
#endif
namespace fs = std::filesystem;
@@ -79,12 +82,36 @@ struct {
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
+ nvmlReturn_t (*nvmlDeviceGetName)(nvmlDevice_t, char *, unsigned int);
const char * (*nvmlErrorString)(nvmlReturn_t result);
-} nvml { NULL, NULL, NULL, NULL, NULL };
+} nvml { NULL, NULL, NULL, NULL, NULL, NULL, NULL };
static std::mutex ggml_nvml_lock;
extern "C" {
+#ifndef _WIN32
+// Helper function to get available memory from /proc/meminfo on Linux
+// Returns MemAvailable as calculated by the kernel
+static size_t get_mem_available() {
+ std::ifstream meminfo("/proc/meminfo");
+ if (!meminfo.is_open()) {
+ return 0;
+ }
+
+ std::string line;
+ while (std::getline(meminfo, line)) {
+ if (line.find("MemAvailable:") == 0) {
+ size_t available_kb;
+ sscanf(line.c_str(), "MemAvailable: %zu kB", &available_kb);
+ // Convert from kB to bytes
+ return available_kb * 1024;
+ }
+ }
+
+ return 0;
+}
+#endif
+
int ggml_nvml_init() {
std::lock_guard<std::mutex> lock(ggml_nvml_lock);
if (nvml.handle != NULL) {
@@ -117,8 +144,9 @@ int ggml_nvml_init() {
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
+ nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetName");
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
- if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL || nvml.nvmlErrorString == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
@@ -151,8 +179,9 @@ int ggml_nvml_init() {
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
+ nvml.nvmlDeviceGetName = (nvmlReturn_t (*)(nvmlDevice_t, char *, unsigned int)) dlsym(nvml.handle, "nvmlDeviceGetName");
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
- if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlDeviceGetName == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
dlclose(nvml.handle);
nvml.handle = NULL;
@@ -199,10 +228,46 @@ int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
}
nvmlMemory_t memInfo = {0};
status = nvml.nvmlDeviceGetMemoryInfo(device, &memInfo);
+
if (status == NVML_SUCCESS) {
+ // NVML working correctly, use its values
*free = memInfo.free;
*total = memInfo.total;
+ return NVML_SUCCESS;
}
+
+#ifndef _WIN32
+ // Handle NVML_ERROR_NOT_SUPPORTED - this indicates NVML doesn't support
+ // reporting framebuffer memory (e.g., unified memory GPUs where FB memory is 0)
+ if (status == NVML_ERROR_NOT_SUPPORTED) {
+ // Use system memory from /proc/meminfo
+ size_t mem_available = get_mem_available();
+ size_t mem_total = 0;
+
+ // Read MemTotal
+ std::ifstream meminfo("/proc/meminfo");
+ if (meminfo.is_open()) {
+ std::string line;
+ while (std::getline(meminfo, line)) {
+ if (line.find("MemTotal:") == 0) {
+ size_t total_kb;
+ sscanf(line.c_str(), "MemTotal: %zu kB", &total_kb);
+ mem_total = total_kb * 1024;
+ break;
+ }
+ }
+ }
+
+ if (mem_total > 0) {
+ *total = mem_total;
+ *free = mem_available;
+ GGML_LOG_INFO("%s NVML not supported for memory query, using system memory (total=%zu, available=%zu)\n",
+ __func__, mem_total, mem_available);
+ return NVML_SUCCESS;
+ }
+ }
+#endif
+
return status;
}

View File

@@ -1,49 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Julius Tischbein <ju.tischbein@gmail.com>
Date: Wed, 15 Oct 2025 13:54:15 +0200
Subject: [PATCH] CUDA: Changing the CUDA scheduling strategy to spin (#16585)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* CUDA set scheduling strategy to spinning for cc121
* Using prop.major and prop.minor, include HIP and MUSA
* Exclude HIP and MUSA
* Remove trailing whitespace
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
* Remove empty line
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---
ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6a278b5e9..87941f872 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -340,6 +340,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
+
+ // Temporary performance fix:
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
+ // TODO: Check for future drivers the default scheduling strategy and
+ // remove this call again when cudaDeviceScheduleSpin is default.
+ if (prop.major == 12 && prop.minor == 1) {
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
+ }
+
#endif // defined(GGML_USE_HIP)
}

View File

@@ -1,32 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri, 17 Oct 2025 14:17:00 -0700
Subject: [PATCH] report LoadLibrary failures
---
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index f794d9cfa..3a855ab2e 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -118,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
+ if (!handle) {
+ DWORD error_code = GetLastError();
+ std::string msg;
+ LPSTR lpMsgBuf = NULL;
+ DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
+ if (bufLen) {
+ msg = lpMsgBuf;
+ LocalFree(lpMsgBuf);
+ GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
+ }
+ }
SetErrorMode(old_mode);

View File

@@ -567,7 +567,6 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
(runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
(gpus[0].Library == "cpu" && s.options.UseMMap == nil) ||
(gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
(s.options.UseMMap != nil && !*s.options.UseMMap) {
s.loadRequest.UseMmap = false
}
@@ -928,7 +927,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
}
}
libraryGpuLayers := assignLayers(layers, gl, requireFull, s.options.NumGPU, lastUsedGPU)
libraryGpuLayers := assignLayers(layers, gl, s.options.NumGPU, lastUsedGPU)
if libraryGpuLayers.Sum() > gpuLayers.Sum() {
gpuLayers = libraryGpuLayers
}
@@ -994,7 +993,7 @@ nextLayer:
}
// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
func assignLayers(layers []uint64, gpus discover.GpuInfoList, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
func assignLayers(layers []uint64, gpus discover.GpuInfoList, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
// If we can't fit everything then prefer offloading layers other than the output layer
for range 2 {
// requestedLayers may be -1 if nothing was requested
@@ -1003,14 +1002,14 @@ func assignLayers(layers []uint64, gpus discover.GpuInfoList, requireFull bool,
if !envconfig.SchedSpread() {
for i := lastUsedGPU; i < len(gpus); i++ {
// Try to pack things into as few GPUs as possible
forceRequest := i == len(gpus)-1 && !requireFull
forceRequest := i == len(gpus)-1
gpuLayers = findBestFit(layers, gpus[:i+1], requestedLayers, forceRequest)
if gpuLayers.Sum() == len(layers) || gpuLayers.Sum() == requestedLayers {
break
}
}
} else {
gpuLayers = findBestFit(layers, gpus, requestedLayers, !requireFull)
gpuLayers = findBestFit(layers, gpus, requestedLayers, true)
}
// We only stop if we've gotten all of the layers - even if we got requestedLayers, we still

View File

@@ -127,14 +127,6 @@ func TestLLMServerFitGPU(t *testing.T) {
requireFull: true,
expectedErr: ErrLoadRequiredFull,
},
{
name: "requireFull numGPU",
gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
numGPU: 4,
requireFull: true,
expectedErr: ErrLoadRequiredFull,
},
}
for _, tt := range tests {

View File

@@ -57,8 +57,7 @@ var initDevices = sync.OnceFunc(func() {
}
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
accels = append(accels, d)
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
gpus = append(gpus, d)
}
@@ -471,9 +470,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
// Mimic llama runner logs summarizing layers and memory
gpuLayers := 0
for layer := range maps.Values(b.layers) {
switch C.ggml_backend_dev_type(layer.d) {
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
gpuLayers++
}
}
@@ -482,8 +479,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
switch C.ggml_backend_dev_type(b.output) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
slog.Info("offloading output layer to CPU")
case C.GGML_BACKEND_DEVICE_TYPE_GPU,
C.GGML_BACKEND_DEVICE_TYPE_IGPU:
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
slog.Info("offloading output layer to GPU")
gpuLayers++
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
@@ -726,9 +722,6 @@ func (b *Backend) BackendDevices() []ml.DeviceInfo {
}
info.PCIID = fmt.Sprintf("%02x:%02x.%x", props.pci_bus_id, props.pci_device_id, props.pci_domain_id)
info.LibraryPath = ggml.LibPaths()
if props.numeric_id != nil {
info.FilteredID = C.GoString(props.numeric_id)
}
C.ggml_backend_dev_memory(dev, &props.memory_free, &props.memory_total)
info.TotalMemory = (uint64)(props.memory_total)

View File

@@ -20,14 +20,10 @@ include /src/ggml-cuda/vendors/
include /src/ggml-cuda/template-instances/
include /src/ggml-hip/
include /src/ggml-metal/
include src/ggml-vulkan/
include src/ggml-vulkan/vulkan-shaders
include CMakeLists.txt
include *.[chm]
include *.cpp
include *.cu
include *.cuh
include *.metal
include *.comp
include *.glsl
hide *

View File

@@ -178,8 +178,6 @@ extern "C" {
int pci_device_id;
int pci_domain_id;
const char *library;
// number with which the devices are accessed (Vulkan)
const char *numeric_id;
};
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);

View File

@@ -118,18 +118,6 @@ static dl_handle * dl_load_library(const fs::path & path) {
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
if (!handle) {
DWORD error_code = GetLastError();
std::string msg;
LPSTR lpMsgBuf = NULL;
DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
if (bufLen) {
msg = lpMsgBuf;
LocalFree(lpMsgBuf);
GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
}
}
SetErrorMode(old_mode);

View File

@@ -340,15 +340,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
turing_devices_without_mma.push_back({ id, device_name });
}
// Temporary performance fix:
// Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
// TODO: Check for future drivers the default scheduling strategy and
// remove this call again when cudaDeviceScheduleSpin is default.
if (prop.major == 12 && prop.minor == 1) {
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
}
#endif // defined(GGML_USE_HIP)
}
@@ -4159,6 +4150,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
if (!initialized) {
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
int driverVersion = 0;
CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
for (int i = 0; i < ggml_cuda_info().device_count; i++) {
ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@@ -4176,9 +4168,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->major = prop.major;
dev_ctx->minor = prop.minor;
if (driverVersion == 0) {
CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
}
dev_ctx->driver_major = driverVersion / 1000;
dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10;
dev_ctx->integrated = prop.integrated;

View File

@@ -1,211 +0,0 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED)
function(detect_host_compiler)
if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
else()
find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH)
find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH)
endif()
set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE)
set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE)
endfunction()
# Function to test shader extension support
# Parameters:
# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product")
# TEST_SHADER_FILE - Path to the test shader file
# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result
function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE)
execute_process(
COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error
)
if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*")
message(STATUS "${EXTENSION_NAME} not supported by glslc")
set(${RESULT_VARIABLE} OFF PARENT_SCOPE)
else()
message(STATUS "${EXTENSION_NAME} supported by glslc")
set(${RESULT_VARIABLE} ON PARENT_SCOPE)
add_compile_definitions(${RESULT_VARIABLE})
# Ensure the extension support is forwarded to vulkan-shaders-gen
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE)
endif()
endfunction()
if (Vulkan_FOUND)
message(STATUS "Vulkan found")
ggml_add_backend_library(ggml-vulkan
ggml-vulkan.cpp
../../include/ggml-vulkan.h
)
set(VULKAN_SHADER_GEN_CMAKE_ARGS "")
# Test all shader extensions
test_shader_extension_support(
"GL_KHR_cooperative_matrix"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp"
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_NV_cooperative_matrix2"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp"
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_integer_dot_product"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_bfloat16"
"${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp"
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
)
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
# Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
# Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
endif()
if (GGML_VULKAN_CHECK_RESULTS)
add_compile_definitions(GGML_VULKAN_CHECK_RESULTS)
endif()
if (GGML_VULKAN_DEBUG)
add_compile_definitions(GGML_VULKAN_DEBUG)
endif()
if (GGML_VULKAN_MEMORY_DEBUG)
add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG)
endif()
if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON)
endif()
if (GGML_VULKAN_VALIDATE)
add_compile_definitions(GGML_VULKAN_VALIDATE)
endif()
if (GGML_VULKAN_RUN_TESTS)
add_compile_definitions(GGML_VULKAN_RUN_TESTS)
endif()
# Set up toolchain for host compilation whether cross-compiling or not
if (CMAKE_CROSSCOMPILING)
if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN)
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN})
else()
detect_host_compiler()
if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER)
message(FATAL_ERROR "Host compiler not found")
else()
message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}")
endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY)
set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake)
endif()
else()
# For non-cross-compiling, use empty toolchain (use host compiler)
set(HOST_CMAKE_TOOLCHAIN_FILE "")
endif()
include(ExternalProject)
if (CMAKE_CROSSCOMPILING)
list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE})
message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}")
endif()
ExternalProject_Add(
vulkan-shaders-gen
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$<CONFIG>
-DCMAKE_INSTALL_BINDIR=.
-DCMAKE_BUILD_TYPE=$<CONFIG>
${VULKAN_SHADER_GEN_CMAKE_ARGS}
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $<CONFIG>
BUILD_ALWAYS TRUE
# NOTE: When DESTDIR is set using Makefile generators and
# "make install" triggers the build step, vulkan-shaders-gen
# would be installed into the DESTDIR prefix, so it is unset
# to ensure that does not happen.
INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR
${CMAKE_COMMAND} --install . --config $<CONFIG>
)
set (_ggml_vk_host_suffix $<IF:$<STREQUAL:${CMAKE_HOST_SYSTEM_NAME},Windows>,.exe,>)
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp")
# Because external projects do not provide source-level tracking,
# the vulkan-shaders-gen sources need to be explicitly added to
# ensure that changes will cascade into shader re-generation.
file(GLOB _ggml_vk_shaders_gen_sources
CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp"
"${_ggml_vk_input_dir}/*.h")
add_custom_command(
OUTPUT ${_ggml_vk_header}
COMMAND ${_ggml_vk_genshaders_cmd}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
DEPENDS ${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders header"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})
foreach (file_full ${_ggml_vk_shader_files})
get_filename_component(file ${file_full} NAME)
set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp")
add_custom_command(
OUTPUT ${_ggml_vk_target_cpp}
DEPFILE ${_ggml_vk_target_cpp}.d
COMMAND ${_ggml_vk_genshaders_cmd}
--glslc ${Vulkan_GLSLC_EXECUTABLE}
--source ${file_full}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
--target-cpp ${_ggml_vk_target_cpp}
DEPENDS ${file_full}
${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
COMMENT "Generate vulkan shaders for ${file}"
)
target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp})
endforeach()
else()
message(WARNING "Vulkan not found")
endif()

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +0,0 @@
cmake_minimum_required(VERSION 3.19)
project("vulkan-shaders-gen" C CXX)
find_package (Threads REQUIRED)
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat glslc support")
endif()
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
message(STATUS "Enabling coopmat2 glslc support")
endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
message(STATUS "Enabling dot glslc support")
endif()
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
message(STATUS "Enabling bfloat16 glslc support")
endif()
if (GGML_VULKAN_SHADER_DEBUG_INFO)
add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO)
message(STATUS "Enabling shader debug info")
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)

View File

@@ -1,29 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
}
}

View File

@@ -1,69 +0,0 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif
void main() {
uint idx = get_idx();
uint orig_idx = idx;
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
FLOAT_TYPE sum_sq = 0;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
sum_sq += sum*sum;
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
idx += num_threads;
}
#if ADD_RMS
if (p.param3 != 0) {
// reduce the sum within each subgroup, then across subgroups
const uint NumSubgroups = num_threads / gl_SubgroupSize;
sum_sq = subgroupAdd(sum_sq);
if (gl_SubgroupInvocationID == 0) {
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
sum_sq += sumsh[gl_SubgroupID + s];
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
}
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
}
}
#endif
}

View File

@@ -1,42 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint ne0;
uint ne1;
uint s01;
uint s02;
uint s11;
uint s21;
} p;
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) readonly buffer Z {int32_t data_c[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i1 = gl_WorkGroupID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i11 = data_c[i1 + i2 * p.s21];
const uint s1 = p.ne0;
const uint s2 = p.ne0 * p.ne1;
const uint d0 = i1 * s1 + i2 * s2;
const uint a0 = i1 * p.s01 + i2 * p.s02;
const uint b0 = i11 * p.s11;
for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
}
}

View File

@@ -1,60 +0,0 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define FLT_MAX 3.402823466e+38F
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
shared uint tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
if (row >= p.KY) {
return;
}
A_TYPE amax = -FLT_MAX;
uint acol = col;
if (col < p.KX) {
amax = data_a[row*p.KX + col];
}
for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
A_TYPE val = data_a[row*p.KX + i];
if (val > amax) {
amax = val;
acol = i;
}
}
tmp[col] = acol;
tmpmax[col] = amax;
barrier();
[[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
if (col < s && col + s < p.KX) {
if (tmpmax[col] < tmpmax[col + s]) {
tmpmax[col] = tmpmax[col + s];
tmp[col] = tmp[col + s];
}
}
barrier();
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
}
}

View File

@@ -1,79 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
#define ASC 0
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint order;
} p;
shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];
void swap(uint idx0, uint idx1) {
int tmp = dst_row[idx0];
dst_row[idx0] = dst_row[idx1];
dst_row[idx1] = tmp;
}
void argsort(bool needs_bounds_check) {
// bitonic sort
const int col = int(gl_LocalInvocationID.x);
const uint row = gl_WorkGroupID.y;
const uint row_offset = row * p.ncols;
// initialize indices
dst_row[col] = col;
a_sh[col] = data_a[row_offset + col];
barrier();
uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
[[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
uint num_inner_loop_iters = outer_idx + 1;
[[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
const int ixj = int(col ^ j);
int idx_0 = (col & k) == 0 ? col : ixj;
int idx_1 = (col & k) == 0 ? ixj : col;
int sh_idx_0 = dst_row[idx_0];
int sh_idx_1 = dst_row[idx_1];
bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;
if ((idx_0_oob ||
(!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
swap(idx_0, idx_1);
}
barrier();
}
}
if (col < p.ncols) {
if (p.order == ASC) {
data_d[row_offset + col] = dst_row[col];
} else {
data_d[row_offset + p.ncols - col - 1] = dst_row[col];
}
}
}
void main() {
if (p.ncols == BLOCK_SIZE) {
argsort(false);
} else {
argsort(true);
}
}

View File

@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}

View File

@@ -1,41 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const int dim = p.param3;
if (idx >= p.ne) {
return;
}
const uint i3 = idx / (p.ne22*p.ne21*p.ne20);
const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20;
const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20);
const uint i2_offset = i2*p.ne21*p.ne20;
const uint i1 = (idx - i3_offset - i2_offset) / p.ne20;
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20;
uint o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03));
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10;
const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20;
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
#else
if (is_src0) {
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
} else {
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
}
#endif
}

View File

@@ -1,49 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : require
const uint num_threads = 128;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 4;
// fast path for when all four iterations are in-bounds
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
} else {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
#endif
idx += num_threads;
}
}
}

View File

@@ -1,105 +0,0 @@
#version 450
#include "types.glsl"
layout (push_constant) uniform parameter
{
uint ne;
uint batches;
uint channels;
uint dst_w;
uint dst_h;
uint src_w;
uint src_h;
uint knl_w;
uint knl_h;
int stride_x;
int stride_y;
int pad_x;
int pad_y;
int dilation_x;
int dilation_y;
} p;
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
uint i0 = idx / p.dst_w;
uint dst_x = idx - i0 * p.dst_w;
uint i1 = i0 / p.dst_h;
uint dst_y = i0 - i1 * p.dst_h;
uint n = i1 / p.channels;
uint c = i1 - n * p.channels;
uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
uint knl_i = c * p.knl_h * p.knl_w;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
sum = fma(v, k, sum);
}
}
return sum;
}
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
uint i0 = idx / p.channels;
uint c = idx - i0 * p.channels;
uint i1 = i0 / p.dst_w;
uint dst_x = i0 - i1 * p.dst_w;
uint n = i1 / p.dst_h;
uint dst_y = i1 - n * p.dst_h;
uint src_i = n * p.channels * p.src_h * p.src_w;
uint src_row = p.src_w * p.channels;
uint knl_row = p.knl_w * p.channels;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
sum = fma(v, k, sum);
}
}
return sum;
}
void main() {
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
FLOAT_TYPE result =
#ifdef WHCN
conv_2d_dw_whcn(idx);
#else
conv_2d_dw_cwhn(idx);
#endif
dst_data[idx] = D_TYPE(result);
}

View File

@@ -1,349 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#ifdef COOPMAT2
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_KHR_memory_scope_semantics : enable
#endif
#ifdef USE_COLLECTIVES
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d
layout(binding = 1) readonly buffer B {
B_TYPE src_data[];
}; // src1 - input: [W, H, Cin, N] -- channel_first format
layout(binding = 2) writeonly buffer D {
D_TYPE dst_data[];
}; // dst - result: [OW, OH, Cout, N]
layout(push_constant) uniform parameter {
// I/O channels, batch size
uint32_t Cout;
uint32_t Cin;
uint32_t N;
// Tensor spatial sizes: kernel, input, output
uint32_t KW;
uint32_t KH;
uint32_t W;
uint32_t H;
uint32_t OW;
uint32_t OH;
// Parameters: stride, padding, dilation - 0=y, 1=x
uint32_t s0;
uint32_t s1;
uint32_t p0;
uint32_t p1;
uint32_t d0;
uint32_t d1;
// Strides in elements
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb1;
uint32_t nb2;
uint32_t nb3;
// fastdiv helper values
uint32_t KWmp; uint32_t KWL;
uint32_t KWKHmp; uint32_t KWKHL;
uint32_t OWmp; uint32_t OWL;
uint32_t OWOHmp; uint32_t OWOHL;
#ifdef TRANSPOSE
uint32_t s0mp; uint32_t s0L;
uint32_t s1mp; uint32_t s1L;
#endif
}
p;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Blocktile sizes
layout(constant_id = 1) const uint BS_K = 128;
layout(constant_id = 2) const uint BS_CRS = 16;
layout(constant_id = 3) const uint BS_NPQ = 128;
// Thread-tile sizes
layout(constant_id = 4) const uint TS_K = 8;
layout(constant_id = 5) const uint use_collectives = 1;
layout(constant_id = 6) const uint SHMEM_PAD = 4;
uint32_t tid = gl_LocalInvocationID.x;
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
uint splitWork(uint work_size, uint block_size) {
return (block_size + work_size - 1) / block_size;
}
uint32_t K = p.Cout;
uint32_t CRS = p.Cin * p.KH * p.KW;
uint32_t NPQ = p.N * p.OH * p.OW;
uint32_t n_elems_out = K * NPQ;
// Number of blocktiles per input
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
#ifdef COOPMAT2
#define SHMEM_TYPE float16_t
#else
#define SHMEM_TYPE float
#endif
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
const uint32_t Ash_numel = BS_K * BS_CRS;
const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
const uint32_t Ash_len = BS_K * Ash_stride;
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
// Threadtile sizes
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
// Number of threadtiles per blocktile
const uint32_t NT_K = BS_K / TS_K;
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
/*
Compute
KxCRS @ CRSxNPQ = K x NPQ
K=Cout
C=Cin
R,S=KH,KW
P,Q=OH,OW
*/
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y;
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
#ifdef COOPMAT2
#define ACC_TYPE float16_t
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
uint32_t K_idx = B_idx_K * BS_K + r;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = D_TYPE(elem);
}
return elem;
}
#endif
void main() {
#ifdef COOPMAT2
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
float regC[TS_K][TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = 0.0;
}
}
#endif
/* Advance block in CRS dim */
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
uint32_t CRS_idx_a;
uint32_t Cin_idx_a;
uint32_t KH_idx_a;
uint32_t KW_idx_a;
#ifdef USE_COLLECTIVES
uint32_t cached_CRS_idx;
uint32_t cached_Cin_idx;
uint32_t cached_KH_idx;
uint32_t cached_KW_idx;
if (use_collectives == 1) {
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
} else {
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
}
#else
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
#endif
/* Load kernel to A_block: (BS_K x BS_CRS)*/
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
uint32_t B_ly = r_offset + Ar;
uint32_t B_lx = Ac;
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
#ifdef TRANSPOSE
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
#else
uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
#endif
float val = knl_data[knl_idx];
if (K_idx >= K || CRS_idx_a >= CRS) {
val = 0.0;
}
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
}
/* Load input to B_block: (BS_CRS x BS_NPQ) */
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
uint32_t B_ly = r_offset + Br; /* Row index of B block */
uint32_t B_lx = Bc;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
uint32_t CRS_idx_b;
uint32_t Cin_idx_b;
uint32_t KH_idx_b;
uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
if (use_collectives == 1) {
CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
} else {
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
}
#else
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif
#ifdef TRANSPOSE
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
#else
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
#endif
uint32_t src_idx =
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
float val = src_data[src_idx];
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
#ifdef TRANSPOSE
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
#endif
) {
val = 0.0;
}
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
}
barrier();
#ifdef COOPMAT2
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
matC = coopMatMulAdd(matA, matB, matC);
#else
if (T_y * TS_K < K) {
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
float regA[TS_K];
float regB[TS_NPQ];
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
}
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
}
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
}
}
}
}
#endif
barrier();
}
/* Save C* */
#ifdef COOPMAT2
coopMatPerElementNV(matC, matC, perElemOpStore);
#else
if (T_y * TS_K < K) {
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
if (K_idx < K && NPQ_idx < NPQ) {
dst_data[dst_idx] = regC[T_ly][T_lx];
}
}
}
}
#endif
}

View File

@@ -1,98 +0,0 @@
#version 450
#include "types.glsl"
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout]
layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in;
layout (push_constant) uniform parameter {
uint32_t Cout;
uint32_t Cin;
uint32_t K;
uint32_t L;
uint32_t KL;
uint32_t nb01;
uint32_t nb02;
uint32_t nb11;
uint32_t nb1;
int32_t s0;
} p;
uint32_t Cout_idx = gl_WorkGroupID.x;
const uint32_t bs = gl_WorkGroupSize.x;
uint32_t tid = gl_LocalInvocationID.x;
// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K.
uint32_t tmp_len = bs*p.s0+p.K;
shared D_TYPE tmp[4096];
uint splitWork(uint workSize){
return (bs + workSize -1) / bs;
}
void main(){
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx < tmp_len){
tmp[idx] = 0.0;
}
}
uint32_t L_blocks = splitWork(p.L);
for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
if(L_block_id > 0){
barrier();
// Shift values in tmp to the current processing window
for(int i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
if(idx >= bs*p.s0 && idx < tmp_len){
tmp[idx-bs*p.s0] = tmp[idx];
tmp[idx] = 0.0;
}else if(idx >= p.K && idx < bs*p.s0){
tmp[idx] = 0.0;
}
}
}
barrier();
// Save contributions of the block to tmp
uint32_t L_idx = L_block_id*bs + tid;
for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
D_TYPE dp = 0.0;
for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
if(L_idx < p.L){
B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
dp = fma(elemKrn, elemInp, dp);
}
}
tmp[tid*p.s0 + K_idx] += dp;
barrier();
}
// Save the computed values except the last block that can have different size
uint32_t KLb_idx = L_block_id*bs*p.s0;
if(L_block_id < L_blocks-1){
for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){
uint32_t sh_idx = p.s0*tid+s0_idx;
uint32_t KL_idx = KLb_idx+sh_idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
}
}
}
}
for(uint32_t i = 0; i < splitWork(tmp_len); i++){
uint32_t idx = i*bs+tid;
uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx;
if(KL_idx < p.KL){
data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
}
}
}

View File

@@ -1,23 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
#endif
}

View File

@@ -1,51 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
#include "dequant_funcs.glsl"
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;
}
#endif
const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = get_doffset() + dst_idx(idx);
uint src_idx = src0_idx_quant(idx, QUANT_K);
const uint a_offset = 0;
const uint ib = src_idx;
const vec2 dm = get_dm(ib, a_offset);
[[unroll]] for (int j = 0; j < QUANT_K; j += 4) {
vec4 v = dequantize4(ib, j / QUANT_R, a_offset);
v = v * dm.x + vec4(dm.y);
#if QUANT_R == 2
data_d[dst_idx + j/2 + 0] = v[0];
data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1];
data_d[dst_idx + j/2 + 1] = v[2];
data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3];
#else
data_d[dst_idx + j + 0] = v[0];
data_d[dst_idx + j + 1] = v[1];
data_d[dst_idx + j + 2] = v[2];
data_d[dst_idx + j + 3] = v[3];
#endif
}
}

View File

@@ -1,296 +0,0 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#if defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 512;
#else
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
const uint BLOCK_SIZE = 32;
#endif
layout (binding = 0) readonly buffer S {float data_s[];};
#if defined(SET_ROWS)
#include "generic_binary_head.glsl"
layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#if B_SIZE == 64
#define DATA_I_SWIZZLE .x
#else
#define DATA_I_SWIZZLE
#endif
#else
#include "generic_unary_head.glsl"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#endif
#if defined(DATA_A_Q4_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id;
const uint xi0 = min(15, int(x0 + 8.5));
const uint xi1 = min(15, int(x1 + 8.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q4_1)
void quantize(uint dst_idx, uint src_idx)
{
float vmin = 1.0/0.0;
float vmax = -vmin;
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) {
const float v = data_s[src_idx + j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(vmin);
[[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - vmin)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id;
const uint xi0 = min(15, int(x0 + 0.5));
const uint xi1 = min(15, int(x1 + 0.5));
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
}
}
#endif
#if defined(DATA_A_Q5_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id;
const uint xi0 = min(31, int(x0 + 16.5));
const uint xi1 = min(31, int(x1 + 16.5));
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2);
}
data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF);
data_q[dst_idx].qh[1] = uint16_t(qh >> 16);
}
#endif
#if defined(DATA_A_Q5_1)
void quantize(uint dst_idx, uint src_idx)
{
float min = data_s[src_idx + 0];
float max = min;
[[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) {
const float v = data_s[src_idx + j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = (d != 0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
data_q[dst_idx].m = float16_t(min);
uint32_t qh = 0;
[[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) {
const float x0 = (data_s[src_idx + 0 + j] - min)*id;
const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id;
const uint xi0 = uint(x0 + 0.5);
const uint xi1 = uint(x1 + 0.5);
data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4));
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2);
}
data_q[dst_idx].qh = qh;
}
#endif
#if defined(DATA_A_Q8_0)
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0; // absolute max
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) {
const float v = data_s[src_idx + j];
amax = max(amax, abs(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = (d != 0.0) ? 1.0/d : 0.0;
data_q[dst_idx].d = float16_t(d);
[[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) {
const float x0 = data_s[src_idx + j]*id;
data_q[dst_idx].qs[j] = int8_t(round(x0));
}
}
#endif
#if defined(DATA_A_IQ4_NL)
uint best_index(float x) {
if (x <= kvalues_iq4nl[0]) return 0;
if (x >= kvalues_iq4nl[15]) return 15;
int ml = 0, mu = 15;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav;
}
return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu;
}
void quantize(uint dst_idx, uint src_idx)
{
float amax = 0.0;
float vmax = 0.0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) {
const float v = data_s[src_idx + j];
if (amax < abs(v)) {
amax = abs(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = (d != 0.0) ? 1.0/d : 0.0;
float sumqx = 0, sumq2 = 0;
[[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) {
const float x0 = data_s[src_idx + 0 + j]*id;
const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id;
const uint xi0 = best_index(x0);
const uint xi1 = best_index(x1);
data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4));
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j];
const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(data_s[src_idx]);
}
#endif
#if defined(DATA_A_BF16)
void quantize(uint dst_idx, uint src_idx)
{
data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx]));
}
#endif
#if defined(SET_ROWS)
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
uint i12 = fastmod(i03, p.ne12);
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;
uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
quantize(dst_idx, src0_idx);
}
#else
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K;
if (idx >= p.ne) {
return;
}
uint dst_idx = dst_idx_quant(idx, QUANT_K);
uint src_idx = get_aoffset() + src0_idx(idx);
quantize(dst_idx, src_idx);
}
#endif

View File

@@ -1,17 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
}

View File

@@ -1,31 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.glsl"
#include "generic_head.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
const uint CHUNK_SIZE = 512;
void main() {
const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
const uint col = gl_LocalInvocationID.x;
uint count = 0;
[[unroll]]
for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
const uint idx = base + i + col;
if (idx >= p.KX) {
break;
}
count += uint(data_a[idx] == data_b[idx]);
}
atomicAdd(data_d[0], D_TYPE(count));
}

View File

@@ -1,20 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_GlobalInvocationID.x * 16;
if (i >= p.nel) {
return;
}
[[unroll]] for (uint l = 0; l < 16; l++) {
data_b[i + l] = D_TYPE(data_a[i + l]);
}
}

View File

@@ -1,616 +0,0 @@
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
#endif
#if defined(DATA_A_F16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
}
#endif
#if defined(DATA_A_BF16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
}
#endif
#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2(vui & 0xF, vui >> 4) - 8.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
}
#endif
#if defined(DATA_A_Q4_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(vui & 0xF, vui >> 4);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12);
}
#endif
#if defined(DATA_A_Q5_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f);
}
#endif
#if defined(DATA_A_Q5_1)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = data_a[a_offset + ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y);
}
#endif
#if defined(DATA_A_Q8_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_IQ1_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const int i8 = int(iqs % 8);
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qs = data_a[a_offset + ib].qs[ib8];
const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
// Signed bitfield extract.
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
return dl * (vec2(gvec) + delta);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const int i8 = int(iqs % 8);
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qs = data_a[a_offset + ib].qs[ib8];
const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1;
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
// Signed bitfield extract.
const ivec4 gvec = ivec4(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2),
bitfieldExtract(grid, 2 * (i8 + 2), 2),
bitfieldExtract(grid, 2 * (i8 + 3), 2)
);
return dl * (vec4(gvec) + delta);
}
#endif
#if defined(DATA_A_IQ1_M)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib8 = iqs / 8;
const uint ib16 = iqs / 16;
const int i8 = int(iqs % 8);
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Signed bitfield extract.
const ivec2 gvec = ivec2(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2)
);
return dl * (vec2(gvec) + delta);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib8 = iqs / 8;
const uint ib16 = iqs / 16;
const int i8 = int(iqs % 8);
const uint sc = data_a[a_offset + ib].scales[iqs / 64];
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
// Signed bitfield extract.
const ivec4 gvec = ivec4(
bitfieldExtract(grid, 2 * (i8), 2),
bitfieldExtract(grid, 2 * (i8 + 1), 2),
bitfieldExtract(grid, 2 * (i8 + 2), 2),
bitfieldExtract(grid, 2 * (i8 + 3), 2)
);
return dl * (vec4(gvec) + delta);
}
#endif
#if defined(DATA_A_IQ2_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = (iqs / 8) % 4;
const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2],
data_a_packed16[a_offset + ib].qs[4 * ib32 + 3]));
const float db = 0.25 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[iqs / 8];
const float db = 0.25 * (0.5 + scale);
const uint sign7 = qs >> 9;
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ2_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid[iqs % 4] * (sign0 ? -1.0 : 1.0),
grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint ib8 = iqs / 8;
const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf;
const uint qs = data_a[a_offset + ib].qs[ib8];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8);
const float db = 0.25 * (0.5 + scale);
const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_XXS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
// Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale)
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4)));
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
return db * vec2(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint is = QUANT_K / 4 + 4 * ib32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2],
data_a_packed16[a_offset + ib].qs[is / 2 + 1]));
const float db = 0.5 * (0.5 + (signs >> 28));
const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7);
// Add parity bit
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint sign = sign8 >> (iqs % 8);
const u8vec4 grid = unpack8(iq3xxs_grid[qs]);
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
return db * vec4(
grid.x * (sign0 ? -1.0 : 1.0),
grid.y * (sign1 ? -1.0 : 1.0),
grid.z * (sign2 ? -1.0 : 1.0),
grid.w * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ3_S)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint qs = data_a[a_offset + ib].qs[iqs / 4];
const uint qh = data_a[a_offset + ib].qh[iqs / 32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[iqs / 64];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4));
return db * vec2(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0)
);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib4 = iqs / 4;
const uint ib32 = iqs / 32;
const uint qs = data_a[a_offset + ib].qs[ib4];
const uint qh = data_a[a_offset + ib].qh[ib32];
const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8);
const uint scale = data_a[a_offset + ib].scales[ib32 / 2];
bool sign0 = (sign & 1) != 0;
bool sign1 = (sign & 2) != 0;
bool sign2 = (sign & 4) != 0;
bool sign3 = (sign & 8) != 0;
const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf);
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4));
return db * vec4(
int(grid & 0xFF) * (sign0 ? -1.0 : 1.0),
int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0),
int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0),
int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0)
);
}
#endif
#if defined(DATA_A_IQ4_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec4 qs = u8vec4(
data_a[a_offset + ib].qs[iq + 0],
data_a[a_offset + ib].qs[iq + 1],
data_a[a_offset + ib].qs[iq + 2],
data_a[a_offset + ib].qs[iq + 3]
);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec4(
kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
}
#endif
#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
vec2 v1 = dequantize(ib, iqs + 1, a_offset);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}
#endif
#if defined(DATA_A_IQ1_M)
vec2 get_dm(uint ib, uint a_offset) {
const uint16_t[4] scales = data_a[a_offset + ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
return vec2(d, 0);
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
#if defined(DATA_A_Q2_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
const vec2 d = vec2(data_a[a_offset + ib].d);
return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q3_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 64; // 0,1
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const uint hmi = (iqs % 16) * 2; // 0,2,4..30
const uint j = (iqs % 64) / 4; // 0..3
const uint is = iqs / 8; // 0..15
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[a_offset + ib].d) * float(us - 32);
return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q4_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m),
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q5_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const uint qhi = (iqs % 16) * 2; // 0,2,4..30
const uint8_t hm = uint8_t(1 << (iqs / 16));
const vec2 loadd = vec2(data_a[a_offset + ib].d);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;
const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;
return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif
#if defined(DATA_A_Q6_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
iqs /= 2;
const uint n = iqs / 64; // 0,1
const uint b = (iqs % 64) / 32; // 0,1
const uint is_b = (iqs % 16) / 8; // 0,1
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uint is = 8 * n + qhshift + is_b; // 0..15
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]);
return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
}
#endif

View File

@@ -1,720 +0,0 @@
#include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
qs >>= shift;
qs &= 0x0F0F;
qs = unpack8(qs)[idx & 1];
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
block_q4_1 block;
};
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs) * d + m;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
block_q5_0 block;
};
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
block_q5_1 block;
};
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint uint_qh = bl.block.qh;
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs | qh) * d + m;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
block_q8_0_packed16 block;
};
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;
// Load 16b and select the byte for this element
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
float16_t ret = float16_t(qs) * d;
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
block_q2_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
block_q2_K_packed16 block;
};
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qs = (qs >> qsshift) & 0x0303;
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
block_q3_K block;
};
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;
const uint n = iqs / 128; // 0,1
const uint qsi = n * 32 + (iqs % 32); // 0..63
const uint hmi = (iqs % 32); // 0..31
const uint j = (iqs % 128) / 8; // 0..15
const uint is = iqs / 16; // 0..15
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
uint32_t scaleidx0 = (is < 8) ? is : (is-8);
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
uint32_t scaleidx1 = is + 8 - (is/4)*4;
uint32_t scaleidx1shift = (is/4)*2;
const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
const float16_t dl = bl.block.d * float16_t(us - 32);
float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
return ret;
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
block_q4_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
block_q4_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
block_q4_K_packed128 block;
};
#if defined(IS_MUL_MM2)
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
// into shared memory and then process the whole tile using those scales.
// There is a fetch function that loads into private variables and then a store
// function that stores into shared memory.
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
// the part that fetches from the structure (which has a different block layout).
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
shared vec2 shAscales[8 * shAscales_stride];
uvec4 row_v;
#endif
#if defined(DATA_A_Q4_K)
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q4_k_packed128[block_index].q4k[0];
}
}
#endif
#if defined(DATA_A_Q5_K)
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q5_k_packed128[block_index].q5k[0];
}
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
void store_scalesQ4_K(uint tid)
{
barrier();
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
uint is = idx + is_start;
uvec4 v = row_v;
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
}
barrier();
}
#endif
#endif
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q4k[0];
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
#endif
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float ret = d * float(qs) - m;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
block_q5_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
block_q5_K_packed16 block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
block_q5_K_packed128 block;
};
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
#endif
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = ((qh >> is) & 0x101) << 4;
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs | qh)[idx & 1];
float ret = d * float(qs) - m;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
block_q6_K block;
};
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
block_q6_K_packed16 block;
};
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
const uint idx = coordInBlock[1];
const uint b = (idx & 0x40) >> 6; // 0,1
const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
const uint is = (idx & 0xF0) >> 4; // 0..15
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
ql = (ql >> (b * 4)) & 0x0F0F;
uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
qh = ((qh >> qhshift) & 0x0303) << 4;
int q = unpack8(ql | qh)[idx & 1];
float16_t ret = dscale * float16_t(q - 32);
return ret;
}
#if defined(DATA_A_IQ1_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
block_iq1_s block;
};
float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5;
const uint ib8 = (idx & 0xF8) >> 3;
const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
return ret;
}
#endif
#if defined(DATA_A_IQ1_M)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
block_iq1_m block;
};
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
block_iq1_m_packed64 block;
};
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const uint idx = coordInBlock[1];
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
const uint ib8 = (idx & 0xF8) >> 3;
const uint ib16 = (idx & 0xF0) >> 4;
const int i8 = int(idx % 8);
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
return ret;
}
#endif
#if defined(DATA_A_IQ2_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
block_iq2_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
block_iq2_xxs_packed16 block;
};
float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0x18) >> 3; // 0..3
const uint iqs = 8 * ib32 + ib8;
const uint qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
sign |= bitCount(sign) << 7;
uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
#if defined(DATA_A_IQ2_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
block_iq2_xs block;
};
float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint is = (idx & 0xE0) >> 5; // 0..8
const uint sshift = (idx & 0x10) >> 2; // 0,4
const uint iqs = (idx & 0xF8) >> 3; // 0..63
const uint16_t qs = bl.block.qs[iqs];
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
uint sign = uint(qs >> 9);
sign |= bitCount(sign) << 7;
uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
#if defined(DATA_A_IQ2_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
block_iq2_s block;
};
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0xF8) >> 3; // 0..31
const uint qhshift = 2 * (ib8 % 4);
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib32];
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
const float d = float(bl.block.d);
const float db = d * 0.25 * (0.5 + scale);
const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ3_XXS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
block_iq3_xxs block;
};
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
block_iq3_xxs_packed16 block;
};
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
uint idx = coordInBlock[1];
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint signs = pack32(u16vec2(
bl16.block.qs[is/2+0],
bl16.block.qs[is/2+1]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ3_S)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
block_iq3_s block;
};
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint iqh = (idx & 0xE0) >> 5;
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint qh = bl.block.qh[iqh];
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
const uint scale = bl.block.scales[iqs / 16];
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ4_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
block_iq4_xs block;
};
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 16) >> 2;
const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
return ret;
}
#endif
#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
block_iq4_nl block;
};
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
return ret;
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
};
float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float d = e8m0_to_fp32(bl.block.e);
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
return ret;
}
#endif
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)
#define dequantFuncA dequantFuncIQ1_S
#elif defined(DATA_A_IQ1_M)
#define dequantFuncA dequantFuncIQ1_M
#elif defined(DATA_A_IQ2_XXS)
#define dequantFuncA dequantFuncIQ2_XXS
#elif defined(DATA_A_IQ2_XS)
#define dequantFuncA dequantFuncIQ2_XS
#elif defined(DATA_A_IQ2_S)
#define dequantFuncA dequantFuncIQ2_S
#elif defined(DATA_A_IQ3_XXS)
#define dequantFuncA dequantFuncIQ3_XXS
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#elif defined(DATA_A_IQ4_XS)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#endif

View File

@@ -1,13 +0,0 @@
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint M;
uint K;
uint stride_a;
uint stride_b;
uint nel;
} p;
#include "types.glsl"

View File

@@ -1,42 +0,0 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq1_m data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint ib64 = ib32 / 2;
const uint b_idx = 256 * ib + 32 * ib32;
const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = data_a[ib].scales[ib64];
[[unroll]] for (int l = 0; l < 4; ++l) {
const uint ib16 = 2 * ib32 + l / 2;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1));
const uint qs = data_a[ib].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
[[unroll]] for (int j = 0; j < 8; ++j) {
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
}
}
}

View File

@@ -1,35 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq1_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
uint qh = data_a[ib].qh[ib32];
const float d = float(data_a[ib].d);
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qs = data_a[ib].qs[4 * ib32 + l];
const uint hi = bitfieldExtract(qh, 3 * int(l), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]);
[[unroll]] for (int j = 0; j < 8; ++j) {
data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta));
}
}
}

View File

@@ -1,44 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
uint qh = data_a[ib].qh[ib32];
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint qs = data_a[ib].qs[4 * ib32 + l];
const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
qs |= (qh << (8 - 2 * l)) & 0x300;
const uvec2 grid = iq2s_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -1,43 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (32 values with 2 scales)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * ib32;
const float d = float(data_a[ib].d);
const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
const vec2 db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
uint16_t qs = data_a[ib].qs[4 * ib32 + l];
const uint sign7 = qs >> 9;
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uvec2 grid = iq2xs_grid[qs & 511];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -1,49 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[8*is + 4],
data_a[ib].qs[8*is + 5],
data_a[ib].qs[8*is + 6],
data_a[ib].qs[8*is + 7]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.25;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
const uint qs = data_a[ib].qs[8 * is + l];
const uvec2 grid = iq2xxs_grid[qs];
const u8vec4 grid0 = unpack8(grid.x);
const u8vec4 grid1 = unpack8(grid.y);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -1,40 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale nibble.
// Each block contains 4 scale bytes (8 scales) for 256 output values.
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const float d = float(data_a[ib].d);
const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));
// We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
uint qh = data_a[ib].qh[is];
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = 8 * is + l;
const uint qs = data_a[ib].qs[iqs];
const uint gidx = qs | ((qh << (8 - l)) & 256);
const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));
const u8vec4 grid = unpack8(iq3s_grid[gidx]);
data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -1,51 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 scale block (32 values)
// 8 threads handle 1 superblock
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint is = gl_LocalInvocationID.x % 8;
const uint b_idx = 256 * ib + 32 * is;
const uint s_idx = QUANT_K / 4 + 4 * is;
const float d = float(data_a[ib].d);
uint signscale = pack32(u8vec4(
data_a[ib].qs[s_idx + 0],
data_a[ib].qs[s_idx + 1],
data_a[ib].qs[s_idx + 2],
data_a[ib].qs[s_idx + 3]
));
const float db = d * (0.5 + (signscale >> 28)) * 0.5;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
// Restore parity bit.
const uint sign8 = sign7 | (bitCount(sign7) << 7);
const uint qs0 = data_a[ib].qs[8 * is + 2 * l];
const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];
const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);
const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);
data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0));
data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0));
}
}

View File

@@ -1,32 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = float(data_a[ib].d);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@@ -1,34 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (1 scale and 32 quantized values)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const float d = float(data_a[ib].d);
// Scales are 6 bits
const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
| (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
const float dl = d * (int(scale) - 32);
const uint b_idx = 256 * ib + 32 * ib32;
const uint q_idx = 16 * ib32;
[[unroll]] for (uint l = 0; l < 16; ++l) {
data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@@ -1,32 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
init_iq_shmem(gl_WorkGroupSize);
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View File

@@ -1,34 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint ip = tid / 32;
const uint il = tid - 32 * ip;
const uint is = 8 * ip + il / 16;
const uint y_idx = i * QUANT_K + 128 * ip + il;
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
}
}

View File

@@ -1,42 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = uint(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.nel / QUANT_K) {
return;
}
const uint r = gl_LocalInvocationID.x / 4;
const uint tid = r / 2;
const uint is0 = r % 2;
const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4);
const uint n = tid / 4;
const uint j = tid - 4*n;
const uint8_t m = uint8_t(1 << (4*n + j));
const uint is = 8*n + 2*j + is0;
const uint shift = 2*j;
const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
(data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
const uint y_idx = i * QUANT_K + 128 * n + 32 * j;
const uint qs_idx = 32*n;
for (uint l = l0; l < l0 + 4; ++l) {
data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
}
}
}

View File

@@ -1,30 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint q_idx = 8*il;
const uint b_idx = 1024*i + 32*ir + q_idx;
const float d = float(data_a[ib].d);
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f));
data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f));
}
}

View File

@@ -1,32 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q4_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m);
}
}

View File

@@ -1,68 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint il = tid / 8;
const uint ir = tid % 8;
const uint is = 2 * il;
const uint n = 4;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
uint scidx0 = (is < 4) ? is : (is + 4);
uint scidx1 = (is < 4) ? is : (is - 4);
uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint scidxshift1 = (is < 4) ? 0 : 2;
uint mbidx0 = is + 4;
uint mbidx1 = (is < 4) ? is + 4 : is;
uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint mbidxshift0 = (is < 4) ? 0 : 4;
uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint mbidxshift1 = (is < 4) ? 0 : 2;
uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * mbyte;
scidx0 = (is < 4) ? is + 1 : (is + 5);
scidx1 = (is < 4) ? is + 1 : (is - 3);
scidxmask1 = (is < 4) ? 0x30 : 0xC0;
scidxshift1 = (is < 4) ? 0 : 2;
mbidx0 = is + 5;
mbidx1 = (is < 4) ? is + 5 : is + 1;
mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbidxshift0 = (is < 4) ? 0 : 4;
mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * mbyte;
[[unroll]] for (uint l = 0; l < n; ++l) {
data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1);
data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2);
}
}
}

View File

@@ -1,34 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f));
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f));
}
}

View File

@@ -1,35 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q5_1 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 8*il;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint qh = data_a[ib].qh;
const uint q_idx = 8*il;
[[unroll]] for (uint l = 0; l < 8; ++l) {
const uint iqs = q_idx + l;
const uint vui = uint(data_a[ib].qs[iqs]);
data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m);
data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m);
}
}

View File

@@ -1,70 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint ib = gl_WorkGroupID.x * 256 + wgy;
if (ib >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint il = tid / 16;
const uint ir = tid % 16;
const uint is = 2 * il;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
const uint qh_idx = 2 * ir;
uint scidx0 = (is < 4) ? is : (is + 4);
uint scidx1 = (is < 4) ? is : (is - 4);
uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint scidxshift1 = (is < 4) ? 0 : 2;
uint mbidx0 = is + 4;
uint mbidx1 = (is < 4) ? is + 4 : is;
uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint mbidxshift0 = (is < 4) ? 0 : 4;
uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint mbidxshift1 = (is < 4) ? 0 : 2;
uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d1 = dall * sc;
const FLOAT_TYPE m1 = dmin * mbyte;
scidx0 = (is < 4) ? is + 1 : (is + 5);
scidx1 = (is < 4) ? is + 1 : (is - 3);
scidxmask1 = (is < 4) ? 0x30 : 0xC0;
scidxshift1 = (is < 4) ? 0 : 2;
mbidx0 = is + 5;
mbidx1 = (is < 4) ? is + 5 : is + 1;
mbidxmask0 = (is < 4) ? 0xF : 0xF0;
mbidxshift0 = (is < 4) ? 0 : 4;
mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
mbidxshift1 = (is < 4) ? 0 : 2;
sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
const FLOAT_TYPE d2 = dall * sc;
const FLOAT_TYPE m2 = dmin * mbyte;
const uint8_t hm1 = uint8_t(1 << (2 * il ));
const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
}
}

View File

@@ -1,33 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
[[unroll]] for (uint wgy = 0; wgy < 256; wgy++) {
const uint i = gl_WorkGroupID.x * 256 + wgy;
if (i >= p.nel / QUANT_K) {
return;
}
const uint tid = gl_LocalInvocationID.x;
const uint ip = tid / 32;
const uint il = tid - 32 * ip;
const uint is = 8 * ip + il / 16;
const uint y_idx = i * QUANT_K + 128 * ip + il;
const uint ql_idx = 64 * ip + il;
const uint8_t qh = data_a[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
}

View File

@@ -1,31 +0,0 @@
#version 450
#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_q8_0 data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
const uint tid = gl_LocalInvocationID.x % 64;
const uint il = tid/32;
const uint ir = tid%32;
const uint ib = 32*i + ir;
if (ib >= p.nel / 32) {
return;
}
const uint b_idx = 1024*i + 32*ir + 16*il;
const float d = float(data_a[ib].d);
const uint q_idx = 16*il;
[[unroll]] for (uint l = 0; l < 16; l += 2) {
data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]);
data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]);
}
}

View File

@@ -1,34 +0,0 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : enable
layout (push_constant) uniform parameter
{
uint ncols;
uint rows_per_channel;
uint n_past;
} p;
#include "types.glsl"
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint col = gl_GlobalInvocationID.y;
const uint row = gl_GlobalInvocationID.x;
if (col >= p.ncols) {
return;
}
const uint i = row*p.ncols + col;
if (col > p.n_past + row % p.rows_per_channel) {
data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000));
} else {
data_d[i] = D_TYPE(data_a[i]);
}
}

View File

@@ -1,27 +0,0 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}

View File

@@ -1,21 +0,0 @@
#version 450
#include "rte.glsl"
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(exp(float(data_a[i])));
}

View File

@@ -1,383 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][HSK / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
vec4 Of[Br][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
}
float Lf[Br], Mf[Br];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
float slope[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = 1.0;
}
// ALiBi
if (p.max_bias > 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = 0.0;
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
}
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
}
}
}
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
}
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
}
}
}
barrier();
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
}
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
}
eMf[r] = exp(Moldf[r] - Mf[r]);
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowsumf[r] += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] += Pf[r][c] * Vf;
}
}
}
barrier();
}
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float rowmaxf, eMf;
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
}
barrier();
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
}
barrier();
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += tmpshv4[tid + s];
tmpshv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpshv4[d_tid];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ms;
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}

View File

@@ -1,202 +0,0 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t HSK = 32;
layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;
// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;
const bool KV_bounds_check = Clamp != 0;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask_n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
#define SINK_ENABLE_BIT (1<<24)
#define MASK_ENABLE_BIT (1<<16)
#define N_LOG2_MASK 0xFFFF
layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
} else {
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
if (binding_idx == BINDING_IDX_K) {
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
} else {
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
// Load the sink value, indexed by Q's dimension 2.
ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
return ACC_TYPE(data_s[h]);
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
{
N = p.N;
KV = p.KV;
i = gl_WorkGroupID.x;
split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
Tr = CEIL_DIV(N, Br);
start_j = split_k_index * p.split_kv / Bc;
end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
iq2 = gl_WorkGroupID.y * p.gqa_ratio;
iq3 = gl_WorkGroupID.z;
// broadcast factors
rk2 = p.neq2/p.nek2;
rk3 = p.neq3/p.nek3;
rv2 = p.neq2/p.nev2;
rv3 = p.neq3/p.nev3;
// k indices
ik3 = iq3 / rk3;
ik2 = iq2 / rk2;
// v indices
iv3 = iq3 / rv3;
iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
k_stride = p.nb11;
v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
}

View File

@@ -1,418 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
const uint32_t row_split = 4;
const uint32_t rows_per_thread = Br / row_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
const uint32_t MatBr = 16;
const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];
// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];
shared float slope[Br];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
#define tile_row(r) (row_tid * rows_per_thread + (r))
// Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
if ((HSK % 16) != 0) {
[[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
if (i + tid < Br * qstride) {
Qf[i + tid] = f16vec4(0);
}
}
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
if (i + tid < Bc * kshstride) {
ksh[i + tid] = f16vec4(0);
}
}
barrier();
}
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t r = (idx + tid) / (HSK / 4);
if (r < Br && d < HSK / 4 &&
i * Br + r < N) {
Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
}
}
barrier();
ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPEV4(0.0);
}
}
float Lf[rows_per_thread], Mf[rows_per_thread];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
// ALiBi
if (p.max_bias > 0.0f) {
if (tid < Br) {
uint r = tid;
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
barrier();
} else {
if (tid < Br) {
uint r = tid;
slope[r] = 1.0;
}
barrier();
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
if (c < Bc && d < HSK / 4) {
f16vec4 K_Tf = f16vec4(0);
if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
}
ksh[c * kshstride + d] = K_Tf;
}
}
barrier();
// K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
// Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
// This is written transposed in order to allow for N being 8 if implementations need it
coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
}
uint coord = gl_SubgroupID * MatBc * sfshstride;
coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor);
barrier();
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) / Br;
uint32_t r = (idx + tid) % Br;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
}
}
barrier();
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
}
barrier();
}
float eMf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float rowmaxf = NEG_FLT_MAX_OVER_2;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
}
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf[r] = exp(Moldf - Mf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
continue;
}
float Pf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
Lf[r] += Pf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
barrier();
}
// reduce across threads
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE M = Mf[r];
tmpsh[tid] = M;
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
M = max(M, tmpsh[tid ^ s]);
barrier();
tmpsh[tid] = M;
barrier();
}
rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
eMf[r] = exp(Moldf[r] - Mf[r]);
Lf[r] = eMf[r]*Lf[r];
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
FLOAT_TYPE L = Lf[r];
tmpsh[tid] = L;
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
L += tmpsh[tid ^ s];
barrier();
tmpsh[tid] = L;
barrier();
}
Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
Of[r][d] += tmpshv4[tid ^ s];
barrier();
tmpshv4[tid] = Of[r][d];
barrier();
}
Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
if (sink > Mf[r]) {
ms = exp(Mf[r] - sink);
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] *= ACC_TYPE(ms);
}
} else {
vs = exp(sink - Mf[r]);
}
Lf[r] = Lf[r]*ms + vs;
}
}
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (i * Br + tile_row(r) < N) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}

View File

@@ -1,320 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable
#include "types.glsl"
#include "dequant_funcs_cm2.glsl"
#include "flash_attn_base.glsl"
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return max(x, y);
}
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
// Replace matrix elements >= numRows or numCols with 'replace'
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
if (row >= numRows || col >= numCols) {
return replace;
}
return elem;
}
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
return exp(elem);
}
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
return max(elem0, elem1);
}
#if defined(BLOCK_SIZE)
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
// Store the output when doing grouped query attention.
// Rows index by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c < HSV) {
uint32_t offset = (iq2 + r) * HSV + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
init_indices();
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if defined(BLOCK_SIZE)
tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif
tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK);
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV);
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
q_stride &= ~7;
#if !defined(BLOCK_SIZE)
k_stride &= ~7;
v_stride &= ~7;
#endif
m_stride &= ~7;
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
Qf16 *= float16_t(p.scale);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
// ALiBi
if (p.max_bias > 0.0f) {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
uint32_t m_offset = 0;
if (p.nem2 != 1 || p.nem3 != 1) {
m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
S = coopMatMulAdd(Qf16, K_T, S);
if (p.logit_softcap != 0.0f) {
[[unroll]]
for (int k = 0; k < S.length(); ++k) {
S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
}
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
if (nem1_bounds_check) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
}
// Clear padding elements to -inf, so they don't contribute to rowmax
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
(i + 1) * Br > N)) {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
coopMatPerElementNV(M, rowmax, Max, Mold);
coopMatPerElementNV(P, S - M, Exp);
coopMatPerElementNV(eM, Mold - M, Exp);
// Clear padding elements to 0, so they don't contribute to rowsum
if (Clamp != 0 &&
((j + 1) * Bc > KV ||
(i + 1) * Br > N)) {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C);
}
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);
// compute rowsum by multiplying by matrix of all ones.
coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);
rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
rowsum = coopMatMulAdd(P_A, One, rowsum);
coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);
L = eM*L + rowsum;
// This is the "diagonal" matrix in the paper, but since we do componentwise
// multiply rather than matrix multiply it has the diagonal element smeared
// across the row
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
// multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;
// resize L by using smear/reduce
coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;
// resize M by using smear/reduce
coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
// O, Ldiag, Mr all have the same type so all element locations match
[[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) {
ACC_TYPE sink = S[i];
ACC_TYPE ms = ACC_TYPE(1.0f);
ACC_TYPE vs = ACC_TYPE(1.0f);
if (sink > Mr[i]) {
ms = exp(Mr[i] - sink);
O[i] *= ms;
} else {
vs = exp(sink - Mr[i]);
}
Ldiag[i] = Ldiag[i]*ms + vs;
}
}
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
}
O = Ldiag*O;
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
}
}

View File

@@ -1,120 +0,0 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) readonly buffer B {float data_s[];};
layout (binding = 2) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
uint sinks;
} p;
shared float tmpsh[BLOCK_SIZE];
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float m = data_a[m_offset + (k + tid) * lm_stride];
m_max = max(m_max, m);
}
// reduce across the workgroup
tmpsh[tid] = m_max;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
m_max = max(m_max, tmpsh[tid + s]);
tmpsh[tid] = m_max;
}
barrier();
}
m_max = tmpsh[0];
barrier();
// Compute L based on m_max
float L = 0;
for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) {
float l = data_a[l_offset + (k + tid) * lm_stride];
float m = data_a[m_offset + (k + tid) * lm_stride];
L += exp(m - m_max) * l;
}
// reduce across the workgroup
tmpsh[tid] = L;
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
L += tmpsh[tid + s];
tmpsh[tid] = L;
}
barrier();
}
L = tmpsh[0];
float sink;
if (p.sinks != 0) {
sink = data_s[n];
float ms = 1.0f;
float vs = 1.0f;
if (sink > m_max) {
ms = exp(m_max - sink);
} else {
vs = exp(sink - m_max);
}
L = L*ms + vs;
}
L = 1.0 / L;
// D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
// Scale and sum the O contributions based on m_max and store the result to memory
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
if (p.sinks != 0) {
if (sink > m_max) {
float ms = 1.0f;
ms = exp(m_max - sink);
O *= ms;
}
}
O *= L;
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
data_d[iq3 * D * N + D * n + d] = O;
}
}

View File

@@ -1,13 +0,0 @@
#version 450
#include "glu_head.glsl"
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
float op(float a, float b) {
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
}
#include "glu_main.glsl"

View File

@@ -1,27 +0,0 @@
#version 450
#include "glu_head.glsl"
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
float op(float a, float b) {
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
return 0.5f * a * (1.0f + erf_approx) * b;
}
#include "glu_main.glsl"

View File

@@ -1,11 +0,0 @@
#version 450
#include "glu_head.glsl"
const float GELU_QUICK_COEF = -1.702f;
float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
#include "glu_main.glsl"

View File

@@ -1,25 +0,0 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float xi = float(data_a[i]);
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
}

View File

@@ -1,39 +0,0 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
const float p_erf = 0.3275911f;
const float a1_erf = 0.254829592f;
const float a2_erf = -0.284496736f;
const float a3_erf = 1.421413741f;
const float a4_erf = -1.453152027f;
const float a5_erf = 1.061405429f;
const float SQRT_2_INV = 0.70710678118654752440084436210484f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float a = float(data_a[i]);
const float a_div_sqr2 = a * SQRT_2_INV;
const float sign_x = sign(a_div_sqr2);
const float x = abs(a_div_sqr2);
const float t = 1.0f / (1.0f + p_erf * x);
const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
const float erf_approx = sign_x * y;
data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
}

View File

@@ -1,23 +0,0 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_QUICK_COEF = -1.702f;
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x))));
}

View File

@@ -1,51 +0,0 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.glsl"
#include "utils.glsl"
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
uint misalign_offsets;
float param1; float param2; int param3;
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
// true if src0/src1 are the same shape and the indices can be reused without additional modulus
layout(constant_id = 0) const bool norepeat = false;
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
}
uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
if (norepeat) {
return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
} else {
return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
}
}
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
}

View File

@@ -1,9 +0,0 @@
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint KX;
uint KY;
float param1;
float param2;
} p;

Some files were not shown because too many files have changed in this diff Show More