diff --git a/.dockerignore b/.dockerignore index fada7a9b4..76704c36d 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,9 +3,7 @@ ollama app macapp dist -llm/llama.cpp .env .cache test_data -llm/build llama/build diff --git a/.gitattributes b/.gitattributes index f1c8bcb4d..51635caa7 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,11 @@ -llm/ext_server/* linguist-vendored +llama/**/*.cpp linguist-vendored +llama/**/*.hpp linguist-vendored +llama/**/*.h linguist-vendored +llama/**/*.c linguist-vendored +llama/**/*.cu linguist-vendored +llama/**/*.cuh linguist-vendored +llama/**/*.m linguist-vendored +llama/**/*.metal linguist-vendored + * text=auto *.go text eol=lf diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 4090f2066..50177050c 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -1,5 +1,9 @@ name: release +env: + ROCM_WINDOWS_URL: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe + MSYS2_URL: https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe + on: push: tags: @@ -8,7 +12,7 @@ on: jobs: # Full build of the Mac assets build-darwin: - runs-on: macos-12 + runs-on: macos-13 environment: release steps: - uses: actions/checkout@v4 @@ -39,8 +43,8 @@ jobs: APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }} APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }} APPLE_ID: ${{ vars.APPLE_ID }} - SDKROOT: /Applications/Xcode_13.4.1.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk - DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer + SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk + DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer run: | ./scripts/build_darwin.sh @@ -48,8 +52,8 @@ jobs: with: name: dist-darwin path: | - dist/*arwin* - !dist/*-cov + dist/Ollama-darwin.zip + dist/ollama-darwin # Windows builds take a long time to both install the dependencies and build, so parallelize # CPU generation step @@ -60,51 +64,33 @@ jobs: KEY_CONTAINER: ${{ vars.KEY_CONTAINER }} steps: - uses: actions/checkout@v4 + - name: Set make jobs default + run: | + echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append - name: Set Version shell: bash run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV - - uses: 'google-github-actions/auth@v2' - with: - project_id: 'ollama' - credentials_json: '${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}' - - run: echo "${{ vars.OLLAMA_CERT }}" > ollama_inc.crt - - name: install Windows SDK 8.1 to get signtool + - name: Add msys paths run: | - $ErrorActionPreference = "Stop" - write-host "downloading SDK" - Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${env:RUNNER_TEMP}\sdksetup.exe" - Start-Process "${env:RUNNER_TEMP}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait - write-host "Win SDK 8.1 installed" - gci -path 'C:\Program Files (x86)\Windows Kits\' -r -fi 'signtool.exe' - - name: install signing plugin + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools run: | - $ErrorActionPreference = "Stop" - write-host "downloading plugin" - Invoke-WebRequest -Uri 
"https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${env:RUNNER_TEMP}\plugin.zip" - Expand-Archive -Path "${env:RUNNER_TEMP}\plugin.zip" -DestinationPath ${env:RUNNER_TEMP}\plugin\ - write-host "Installing plugin" - & "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet - write-host "plugin installed" + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait - uses: actions/setup-go@v5 with: go-version-file: go.mod cache: true - - run: go get ./... - run: | - $gopath=(get-command go).source | split-path -parent - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$env:PATH" - go generate -x ./... - name: go generate + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } + make dist + name: make - uses: actions/upload-artifact@v4 with: name: generate-windows-cpu path: | - build/**/* - build/**/*.a - llm/build/**/*.a dist/windows-amd64/** # ROCm generation step @@ -115,74 +101,54 @@ jobs: KEY_CONTAINER: ${{ vars.KEY_CONTAINER }} steps: - uses: actions/checkout@v4 + - name: Set make jobs default + run: | + echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append - name: Set Version shell: bash run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV - - uses: 'google-github-actions/auth@v2' - with: - project_id: 'ollama' - credentials_json: '${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}' - - run: echo "${{ vars.OLLAMA_CERT }}" > ollama_inc.crt - - name: install Windows SDK 8.1 to get signtool + - name: Add msys paths run: | - $ErrorActionPreference = "Stop" - write-host "downloading SDK" - Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${env:RUNNER_TEMP}\sdksetup.exe" - Start-Process "${env:RUNNER_TEMP}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait - write-host "Win SDK 8.1 installed" - gci -path 'C:\Program Files (x86)\Windows Kits\' -r -fi 'signtool.exe' - - name: install signing plugin + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools run: | - $ErrorActionPreference = "Stop" - write-host "downloading plugin" - Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${env:RUNNER_TEMP}\plugin.zip" - Expand-Archive -Path "${env:RUNNER_TEMP}\plugin.zip" -DestinationPath ${env:RUNNER_TEMP}\plugin\ - write-host "Installing plugin" - & "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet - write-host "plugin installed" + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait - uses: actions/setup-go@v5 with: go-version-file: go.mod cache: true - - 
name: 'Install ROCm' + # ROCM installation steps + - name: 'Cache ROCm installer' + id: cache-rocm + uses: actions/cache@v4 + with: + path: rocm-install.exe + key: ${{ env.ROCM_WINDOWS_URL }} + - name: 'Conditionally Download ROCm' + if: steps.cache-rocm.outputs.cache-hit != 'true' run: | $ErrorActionPreference = "Stop" - write-host "downloading AMD HIP Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" - write-host "Installing AMD HIP" - Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait - write-host "Completed AMD HIP" + Invoke-WebRequest -Uri "${env:ROCM_WINDOWS_URL}" -OutFile "rocm-install.exe" + - name: 'Install ROCm' + run: | + Start-Process "rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait - name: 'Verify ROCm' run: | & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version - - run: go get ./... - - run: | - $gopath=(get-command go).source | split-path -parent - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$env:PATH" - $env:OLLAMA_SKIP_CPU_GENERATE="1" - $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) - go generate -x ./... - name: go generate - - name: 'gather rocm dependencies' + echo "HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path | select -first 1)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + - name: make rocm runner run: | - $HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) - md "dist\deps\bin\rocblas\library" - cp "${HIP_PATH}\bin\hipblas.dll" "dist\deps\bin\" - cp "${HIP_PATH}\bin\rocblas.dll" "dist\deps\bin\" - cp "${HIP_PATH}\bin\rocblas\library\*" "dist\deps\bin\rocblas\library\" + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } + make help-runners + make dist_rocm - uses: actions/upload-artifact@v4 with: name: generate-windows-rocm path: | - build/**/* dist/windows-amd64/** - - uses: actions/upload-artifact@v4 - with: - name: windows-rocm-deps - path: dist/deps/* # CUDA generation step generate-windows-cuda: @@ -191,88 +157,79 @@ jobs: strategy: matrix: cuda: - - version: "11" - url: 'https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe' - - version: "12" - url: 'https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe' + - version: "11.3" + url: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe + - version: "12.4" + url: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe env: KEY_CONTAINER: ${{ vars.KEY_CONTAINER }} steps: - uses: actions/checkout@v4 + - name: Set make jobs default + run: | + echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV 
-Encoding utf8 -Append - name: Set Version shell: bash run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV - - uses: 'google-github-actions/auth@v2' - with: - project_id: 'ollama' - credentials_json: '${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}' - - run: echo "${{ vars.OLLAMA_CERT }}" > ollama_inc.crt - - name: install Windows SDK 8.1 to get signtool + - name: Install msys2 run: | - $ErrorActionPreference = "Stop" - write-host "downloading SDK" - Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${env:RUNNER_TEMP}\sdksetup.exe" - Start-Process "${env:RUNNER_TEMP}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait - write-host "Win SDK 8.1 installed" - gci -path 'C:\Program Files (x86)\Windows Kits\' -r -fi 'signtool.exe' - - name: install signing plugin + $msys2_url="https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe" + write-host "Downloading msys2" + Invoke-WebRequest -Uri "${msys2_url}" -OutFile "${env:RUNNER_TEMP}\msys2.exe" + write-host "Installing msys2" + Start-Process "${env:RUNNER_TEMP}\msys2.exe" -ArgumentList @("in", "--confirm-command", "--accept-messages", "--root", "C:/msys64") -NoNewWindow -Wait + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools run: | - $ErrorActionPreference = "Stop" - write-host "downloading plugin" - Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${env:RUNNER_TEMP}\plugin.zip" - Expand-Archive -Path "${env:RUNNER_TEMP}\plugin.zip" -DestinationPath ${env:RUNNER_TEMP}\plugin\ - write-host "Installing plugin" - & "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet - write-host "plugin installed" + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang", "make") -NoNewWindow -Wait + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: verify tools + run: | + get-command gcc + gcc --version + get-command make + make --version - uses: actions/setup-go@v5 with: go-version-file: go.mod cache: true - - name: 'Install CUDA ${{ matrix.cuda.version }}' + # CUDA installation steps + - name: 'Cache CUDA installer' + id: cache-cuda + uses: actions/cache@v4 + with: + path: cuda-install.exe + key: ${{ matrix.cuda.url }} + - name: 'Conditionally Download CUDA' + if: steps.cache-cuda.outputs.cache-hit != 'true' run: | $ErrorActionPreference = "Stop" - write-host "downloading CUDA Installer" - Invoke-WebRequest -Uri "${{ matrix.cuda.url }}" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe" - write-host "Installing CUDA" - Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait - write-host "Completed CUDA" + Invoke-WebRequest -Uri "${{ matrix.cuda.url }}" -OutFile "cuda-install.exe" + - name: 'Install CUDA' + run: | + $subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | foreach-object {"${_}_${{ matrix.cuda.version }}"} + Start-Process "cuda-install.exe" -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait + - name: 'Verify CUDA' + run: | + & (resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0] --version $cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path) $cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2' - echo "$cudaPath\bin" >> $env:GITHUB_PATH - echo 
"CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV - echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV - echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV - - name: 'Verify CUDA' - run: nvcc -V - - run: go get ./... - - name: go generate + echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "CUDA_PATH_V${cudaVer}=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + + - name: make cuda runner run: | - $gopath=(get-command go).source | split-path -parent - $cudabin=(get-command nvcc).source | split-path - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$cudabin;$env:PATH" - $env:OLLAMA_SKIP_CPU_GENERATE="1" - go generate -x ./... - - name: 'gather cuda dependencies' - run: | - $NVIDIA_DIR=(resolve-path 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*\bin\')[0] - md "dist\deps" - cp "${NVIDIA_DIR}\cudart64_*.dll" "dist\deps\" - cp "${NVIDIA_DIR}\cublas64_*.dll" "dist\deps\" - cp "${NVIDIA_DIR}\cublasLt64_*.dll" "dist\deps\" + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } + make dist_cuda_v$(($env:CUDA_PATH | split-path -leaf) -replace 'v(\d+).*', '$1') - uses: actions/upload-artifact@v4 with: name: generate-windows-cuda-${{ matrix.cuda.version }} path: | - build/**/* dist/windows-amd64/** - - uses: actions/upload-artifact@v4 - with: - name: windows-cuda-deps-${{ matrix.cuda.version }} - path: dist/deps/* - # windows arm64 generate, go build, and zip file (no installer) # Output of this build is aggregated into the final x86 build @@ -292,6 +249,30 @@ jobs: choco install -y --no-progress git gzip echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "C:\ProgramData\chocolatey\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + # pacman is buggy on win arm64, so we avoid using it, but rely on the binary artifacts + # we download the sfx (7zip bundle) which isn't fully set up, but the binaries we need to build work + - name: Install msys2 x64 + run: | + $url="https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-base-x86_64-20240727.sfx.exe" + write-host "Downloading MSYS2" + Invoke-WebRequest -Uri "$url" -outfile "${env:RUNNER_TEMP}\msys2.exe" + write-host "Installing msys2" + Start-Process "${env:RUNNER_TEMP}\msys2.exe" -ArgumentList @( + '-y', '-oC:\' + ) -NoNewWindow -Wait + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + # since pacman isn't reliable, we just download the tar file and extract directly + - name: Downloading and extracting msys2 make tar file + run: | + $url="https://mirror.msys2.org/msys/x86_64/make-4.4.1-2-x86_64.pkg.tar.zst" + write-host "Downloading make" + Invoke-WebRequest -Uri "$url" -outfile c:\msys64\make.tar.zst + cd c:\msys64; tar -xf make.tar.zst + rm c:\msys64\make.tar.zst + - name: Verify 
Make works properly + run: | + echo $env:PATH + make --version - name: Install Visual Studio 2022 run: | $components = @( @@ -354,7 +335,7 @@ jobs: - name: Set Version run: | $ver=${env:GITHUB_REF_NAME}.trim("v") - write-host VERSION=$ver | Out-File -FilePath ${env:GITHUB_ENV} -Encoding utf8 -Append + echo VERSION=$ver | Out-File -FilePath ${env:GITHUB_ENV} -Encoding utf8 -Append - uses: 'google-github-actions/auth@v2' with: project_id: 'ollama' @@ -385,13 +366,12 @@ jobs: - run: | $gopath=(get-command go).source | split-path -parent $gccpath=(get-command gcc).source | split-path -parent - & "C:\Program Files\Microsoft Visual Studio\2022\Community\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$gccpath;$env:PATH;C:\Program Files\Microsoft Visual Studio\2022\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin" + import-module 'C:\Program Files\Microsoft Visual Studio\2022\Community\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -Arch arm64 -vsinstallpath 'C:\Program Files\Microsoft Visual Studio\2022\Community' -skipautomaticlocation + $env:PATH="$gopath;$gccpath;$env:PATH" echo $env:PATH $env:ARCH="arm64" - .\scripts\build_windows.ps1 buildOllama buildApp gatherDependencies distZip + .\scripts\build_windows.ps1 buildOllama buildApp gatherDependencies sign distZip name: 'Windows Build' - uses: actions/upload-artifact@v4 with: @@ -441,6 +421,24 @@ jobs: write-host "Installing plugin" & "${env:RUNNER_TEMP}\plugin\*\kmscng.msi" /quiet write-host "plugin installed" + - name: Install msys2 + run: | + $msys2_url="https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe" + write-host "Downloading msys2" + Invoke-WebRequest -Uri "${msys2_url}" -OutFile "${env:RUNNER_TEMP}\msys2.exe" + write-host "Installing msys2" + Start-Process "${env:RUNNER_TEMP}\msys2.exe" -ArgumentList @("in", "--confirm-command", "--accept-messages", "--root", "C:/msys64") -NoNewWindow -Wait + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools + run: | + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang", "make") -NoNewWindow -Wait + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: verify tools + run: | + get-command gcc + gcc --version + get-command make + make --version - uses: actions/setup-go@v5 with: go-version-file: go.mod @@ -449,36 +447,29 @@ jobs: - uses: actions/download-artifact@v4 with: name: generate-windows-cpu + path: dist/windows-amd64/ - uses: actions/download-artifact@v4 with: - name: generate-windows-cuda-11 + name: generate-windows-cuda-11.3 + path: dist/windows-amd64/ - uses: actions/download-artifact@v4 with: - name: generate-windows-cuda-12 - - uses: actions/download-artifact@v4 - with: - name: windows-cuda-deps-11 - - uses: actions/download-artifact@v4 - with: - name: windows-cuda-deps-12 - - uses: actions/download-artifact@v4 - with: - name: windows-rocm-deps + name: generate-windows-cuda-12.4 + path: dist/windows-amd64/ - uses: actions/download-artifact@v4 with: name: generate-windows-rocm + path: dist/windows-amd64/ - uses: actions/download-artifact@v4 with: name: windows-arm64 path: dist - - run: dir build - run: | - $gopath=(get-command go).source | split-path -parent - & "C:\Program Files (x86)\Microsoft Visual 
Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$env:PATH" + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' $env:OLLAMA_SKIP_GENERATE="1" + $env:ARCH="amd64" + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } & .\scripts\build_windows.ps1 - uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 26dc732af..8dcc506ba 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,5 +1,11 @@ name: test +env: + ROCM_WINDOWS_URL: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe + MSYS2_URL: https://github.com/msys2/msys2-installer/releases/download/2024-07-27/msys2-x86_64-20240727.exe + CUDA_12_WINDOWS_URL: https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_551.61_windows.exe + CUDA_12_WINDOWS_VER: 12.4 + concurrency: # For PRs, later CI runs preempt previous ones. e.g. a force push on a PR # cancels running CI jobs and starts all new ones. @@ -21,9 +27,7 @@ jobs: changes: runs-on: ubuntu-latest outputs: - GENERATE: ${{ steps.changes.outputs.GENERATE }} - GENERATE_CUDA: ${{ steps.changes.outputs.GENERATE_CUDA }} - GENERATE_ROCM: ${{ steps.changes.outputs.GENERATE_ROCM }} + RUNNERS: ${{ steps.changes.outputs.RUNNERS }} steps: - uses: actions/checkout@v4 with: @@ -38,14 +42,167 @@ jobs: } { - echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**') - echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**') - echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**') + echo RUNNERS=$(changed 'llama/**') } >>$GITHUB_OUTPUT - generate: + runners-linux-cuda: needs: [changes] - if: ${{ needs.changes.outputs.GENERATE == 'True' }} + if: ${{ needs.changes.outputs.RUNNERS == 'True' }} + strategy: + matrix: + cuda-version: + - '11.8.0' + runs-on: linux + container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04 + steps: + - run: | + apt-get update && apt-get install -y git build-essential curl + env: + DEBIAN_FRONTEND: noninteractive + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: go.mod + cache: true + - run: go get ./... + - run: | + git config --global --add safe.directory /__w/ollama/ollama + cores=$(grep '^core id' /proc/cpuinfo |sort -u|wc -l) + make -j $cores cuda_v11 + runners-linux-rocm: + needs: [changes] + if: ${{ needs.changes.outputs.RUNNERS == 'True' }} + strategy: + matrix: + rocm-version: + - '6.1.2' + runs-on: linux + container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} + steps: + - run: | + apt-get update && apt-get install -y git build-essential curl rocm-libs + env: + DEBIAN_FRONTEND: noninteractive + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: go.mod + cache: true + - run: go get ./... 
+ - run: | + git config --global --add safe.directory /__w/ollama/ollama + cores=$(grep '^core id' /proc/cpuinfo |sort -u|wc -l) + make -j $cores rocm + + # ROCm generation step + runners-windows-rocm: + needs: [changes] + if: ${{ needs.changes.outputs.RUNNERS == 'True' }} + runs-on: windows + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: Set make jobs default + run: | + echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + + # ROCM installation steps + - name: 'Cache ROCm installer' + id: cache-rocm + uses: actions/cache@v4 + with: + path: rocm-install.exe + key: ${{ env.ROCM_WINDOWS_URL }} + - name: 'Conditionally Download ROCm' + if: steps.cache-rocm.outputs.cache-hit != 'true' + run: | + $ErrorActionPreference = "Stop" + Invoke-WebRequest -Uri "${env:ROCM_WINDOWS_URL}" -OutFile "rocm-install.exe" + - name: 'Install ROCm' + run: | + Start-Process "rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait + - name: 'Verify ROCm' + run: | + & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + echo "HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path | select -first 1)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + + - name: Add msys paths + run: | + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools + run: | + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait + + - name: make rocm runner + run: | + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } + make -C llama print-HIP_PATH print-HIP_LIB_DIR + make rocm + + # CUDA generation step + runners-windows-cuda: + needs: [changes] + if: ${{ needs.changes.outputs.RUNNERS == 'True' }} + runs-on: windows + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: Set make jobs default + run: | + echo "MAKEFLAGS=--jobs=$((Get-ComputerInfo -Property CsProcessors).CsProcessors.NumberOfCores)" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + + # CUDA installation steps + - name: 'Cache CUDA installer' + id: cache-cuda + uses: actions/cache@v4 + with: + path: cuda-install.exe + key: ${{ env.CUDA_12_WINDOWS_URL }} + - name: 'Conditionally Download CUDA' + if: steps.cache-cuda.outputs.cache-hit != 'true' + run: | + $ErrorActionPreference = "Stop" + Invoke-WebRequest -Uri "${env:CUDA_12_WINDOWS_URL}" -OutFile "cuda-install.exe" + - name: 'Install CUDA' + run: | + $subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | foreach-object {"${_}_${{ env.CUDA_12_WINDOWS_VER }}"} + Start-Process "cuda-install.exe" -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait + - name: 'Verify CUDA' + run: | + & (resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0] --version + $cudaPath=((resolve-path "c:\Program 
Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path) + $cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2' + echo "$cudaPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "CUDA_PATH_V${cudaVer}=$cudaPath" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append + + - name: Add msys paths + run: | + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools + run: | + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait + - name: make cuda runner + run: | + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } + make cuda_v$(($env:CUDA_PATH | split-path -leaf) -replace 'v(\d+).*', '$1') + + runners-cpu: + needs: [changes] + if: ${{ needs.changes.outputs.RUNNERS == 'True' }} strategy: matrix: os: [ubuntu-latest, macos-latest, windows-2019] @@ -58,6 +215,7 @@ jobs: runs-on: ${{ matrix.os }} env: GOARCH: ${{ matrix.arch }} + ARCH: ${{ matrix.arch }} CGO_ENABLED: '1' steps: - uses: actions/checkout@v4 @@ -65,153 +223,31 @@ jobs: with: go-version-file: go.mod cache: true - - run: go get ./... - - run: | + - name: Add msys paths + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait + - name: 'Build Windows Go Runners' + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | $gopath=(get-command go).source | split-path -parent $gccpath=(get-command gcc).source | split-path -parent - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE + import-module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' + Enter-VsDevShell -vsinstallpath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -skipautomaticlocation -DevCmdArguments '-arch=x64 -no_logo' $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" $env:PATH="$gopath;$gccpath;$env:PATH" echo $env:PATH - go generate -x ./... - if: ${{ startsWith(matrix.os, 'windows-') }} - name: 'Windows Go Generate' - - run: go generate -x ./... + if (!(gcc --version | select-string -quiet clang)) { throw "wrong gcc compiler detected - must be clang" } + make -j 4 + - name: 'Build Unix Go Runners' if: ${{ ! startsWith(matrix.os, 'windows-') }} - name: 'Unix Go Generate' + run: make -j 4 - run: go build . 
- generate-cuda: - needs: [changes] - if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }} - strategy: - matrix: - cuda-version: - - '11.8.0' - runs-on: linux - container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04 - steps: - - run: | - apt-get update && apt-get install -y git build-essential curl - curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \ - | tar -zx -C /usr --strip-components 1 - env: - DEBIAN_FRONTEND: noninteractive - - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 - with: - go-version-file: go.mod - cache: true - - run: go get ./... - - run: | - git config --global --add safe.directory /__w/ollama/ollama - go generate -x ./... - env: - OLLAMA_SKIP_CPU_GENERATE: '1' - generate-rocm: - needs: [changes] - if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }} - strategy: - matrix: - rocm-version: - - '6.1.2' - runs-on: linux - container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} - steps: - - run: | - apt-get update && apt-get install -y git build-essential curl rocm-libs - curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \ - | tar -zx -C /usr --strip-components 1 - env: - DEBIAN_FRONTEND: noninteractive - - uses: actions/checkout@v4 - - uses: actions/setup-go@v4 - with: - go-version-file: go.mod - cache: true - - run: go get ./... - - run: | - git config --global --add safe.directory /__w/ollama/ollama - go generate -x ./... - env: - OLLAMA_SKIP_CPU_GENERATE: '1' - - # ROCm generation step - generate-windows-rocm: - needs: [changes] - if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }} - runs-on: windows - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: go.mod - cache: true - - name: 'Install ROCm' - run: | - $ErrorActionPreference = "Stop" - write-host "downloading AMD HIP Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" - write-host "Installing AMD HIP" - Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait - write-host "Completed AMD HIP" - - name: 'Verify ROCm' - run: | - & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version - - run: go get ./... - - run: | - $gopath=(get-command go).source | split-path -parent - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$env:PATH" - $env:OLLAMA_SKIP_CPU_GENERATE="1" - $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) - go generate -x ./... 
- name: go generate - env: - OLLAMA_SKIP_CPU_GENERATE: '1' - - # CUDA generation step - generate-windows-cuda: - needs: [changes] - if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }} - runs-on: windows - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-go@v5 - with: - go-version-file: go.mod - cache: true - - name: 'Install CUDA' - run: | - $ErrorActionPreference = "Stop" - write-host "downloading CUDA Installer" - Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe" - write-host "Installing CUDA" - Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait - write-host "Completed CUDA" - $cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path) - $cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2' - echo "$cudaPath\bin" >> $env:GITHUB_PATH - echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV - echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV - echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV - - name: 'Verify CUDA' - run: nvcc -V - - run: go get ./... - - name: go generate - run: | - $gopath=(get-command go).source | split-path -parent - $cudabin=(get-command nvcc).source | split-path - & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1" - cd $env:GITHUB_WORKSPACE - $env:CMAKE_SYSTEM_VERSION="10.0.22621.0" - $env:PATH="$gopath;$cudabin;$env:PATH" - $env:OLLAMA_SKIP_CPU_GENERATE="1" - go generate -x ./... - env: - OLLAMA_SKIP_CPU_GENERATE: '1' lint: strategy: @@ -233,6 +269,15 @@ jobs: - uses: actions/checkout@v4 with: submodules: recursive + - name: Add msys paths + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait - uses: actions/setup-go@v5 with: go-version-file: go.mod @@ -245,7 +290,7 @@ jobs: shell: bash - uses: golangci/golangci-lint-action@v6 with: - args: --timeout 8m0s -v + args: --timeout 10m0s -v test: strategy: matrix: @@ -260,13 +305,19 @@ jobs: env: GOARCH: ${{ matrix.arch }} CGO_ENABLED: '1' - OLLAMA_CPU_TARGET: 'static' - OLLAMA_SKIP_CPU_GENERATE: '1' - OLLAMA_SKIP_METAL_GENERATE: '1' steps: - uses: actions/checkout@v4 with: submodules: recursive + - name: Add msys paths + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | + echo "c:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + - name: Install msys2 tools + if: ${{ startsWith(matrix.os, 'windows-') }} + run: | + Start-Process "c:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait - uses: actions/setup-go@v5 with: go-version-file: go.mod @@ -277,6 +328,16 @@ jobs: arm64) echo ARCH=arm64 ;; esac >>$GITHUB_ENV shell: bash - - run: go generate ./... - - run: go build - - run: go test -v ./... + - run: go test ./... 
+ + patches: + needs: [changes] + if: ${{ needs.changes.outputs.RUNNERS == 'True' }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + - name: Verify patches carry all the changes + run: | + make apply-patches sync && git diff --compact-summary --exit-code llama diff --git a/.gitignore b/.gitignore index 87f8b0072..caa62a524 100644 --- a/.gitignore +++ b/.gitignore @@ -5,14 +5,11 @@ .swp dist ollama -ggml-metal.metal .cache *.exe .idea test_data *.crt -llm/build -build/*/*/* -!build/**/placeholder llama/build -__debug_bin* \ No newline at end of file +__debug_bin* +llama/vendor \ No newline at end of file diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index b92f645d7..000000000 --- a/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "llama.cpp"] - path = llm/llama.cpp - url = https://github.com/ggerganov/llama.cpp.git - shallow = true \ No newline at end of file diff --git a/.golangci.yaml b/.golangci.yaml index 2e0ed3c7b..9d59fd6c0 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -8,8 +8,6 @@ linters: - containedctx - contextcheck - errcheck - - exportloopref - - gci - gocheckcompilerdirectives - gofmt - gofumpt @@ -30,8 +28,6 @@ linters: - wastedassign - whitespace linters-settings: - gci: - sections: [standard, default, localmodule] staticcheck: checks: - all diff --git a/Dockerfile b/Dockerfile index 0f43e618d..47228df61 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,254 +1,197 @@ -ARG GOLANG_VERSION=1.22.5 -ARG CMAKE_VERSION=3.22.1 +ARG GOLANG_VERSION=1.22.8 ARG CUDA_VERSION_11=11.3.1 -ARG CUDA_V11_ARCHITECTURES="50;52;53;60;61;62;70;72;75;80;86" ARG CUDA_VERSION_12=12.4.0 -ARG CUDA_V12_ARCHITECTURES="60;61;62;70;72;75;80;86;87;89;90;90a" ARG ROCM_VERSION=6.1.2 +ARG JETPACK_6=r36.2.0 +ARG JETPACK_5=r35.4.1 -# Copy the minimal context we need to run the generate scripts -FROM scratch AS llm-code -COPY .git .git -COPY .gitmodules .gitmodules -COPY llm llm - -FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION_11-devel-centos7 AS cuda-11-build-amd64 -ARG CMAKE_VERSION -COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh -ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:$PATH -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -WORKDIR /go/src/github.com/ollama/ollama/llm/generate -ARG CGO_CFLAGS -ARG CUDA_V11_ARCHITECTURES -ENV GOARCH=amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 \ - OLLAMA_SKIP_CPU_GENERATE=1 \ - CMAKE_CUDA_ARCHITECTURES="${CUDA_V11_ARCHITECTURES}" \ - CUDA_VARIANT="_v11" \ - bash gen_linux.sh - -FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION_12-devel-centos7 AS cuda-12-build-amd64 -ARG CMAKE_VERSION -COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh -ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:$PATH -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -WORKDIR /go/src/github.com/ollama/ollama/llm/generate -ARG CGO_CFLAGS -ARG CUDA_V12_ARCHITECTURES -ENV GOARCH=amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 \ - OLLAMA_SKIP_CPU_GENERATE=1 \ - CMAKE_CUDA_ARCHITECTURES="${CUDA_V12_ARCHITECTURES}" \ - CUDA_VARIANT="_v12" \ - OLLAMA_CUSTOM_CUDA_DEFS="-DGGML_CUDA_USE_GRAPHS=on" \ - bash gen_linux.sh - -FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION_11-devel-rockylinux8 AS cuda-11-build-runner-arm64 -ARG CMAKE_VERSION -COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh -ENV 
PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -WORKDIR /go/src/github.com/ollama/ollama/llm/generate -ARG CGO_CFLAGS -ARG CUDA_V11_ARCHITECTURES -ENV GOARCH=arm64 -RUN OLLAMA_SKIP_STATIC_GENERATE=1 \ - OLLAMA_SKIP_CPU_GENERATE=1 \ - CMAKE_CUDA_ARCHITECTURES="${CUDA_V11_ARCHITECTURES}" \ - CUDA_VARIANT="_v11" \ - bash gen_linux.sh - -FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION_12-devel-rockylinux8 AS cuda-12-build-runner-arm64 -ARG CMAKE_VERSION -COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh -ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -WORKDIR /go/src/github.com/ollama/ollama/llm/generate -ARG CGO_CFLAGS -ARG CUDA_V12_ARCHITECTURES -ENV GOARCH=arm64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 \ - OLLAMA_SKIP_CPU_GENERATE=1 \ - CMAKE_CUDA_ARCHITECTURES="${CUDA_V12_ARCHITECTURES}" \ - CUDA_VARIANT="_v12" \ - OLLAMA_CUSTOM_CUDA_DEFS="-DGGML_CUDA_USE_GRAPHS=on" \ - bash gen_linux.sh - - -FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS rocm-build-amd64 -ARG CMAKE_VERSION -COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh -ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:$PATH -ENV LIBRARY_PATH=/opt/amdgpu/lib64 -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -WORKDIR /go/src/github.com/ollama/ollama/llm/generate -ARG CGO_CFLAGS -ARG AMDGPU_TARGETS -ENV GOARCH=amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_SKIP_CPU_GENERATE=1 bash gen_linux.sh -RUN mkdir -p ../../dist/linux-amd64-rocm/lib/ollama && \ - (cd /opt/rocm/lib && tar cf - rocblas/library) | (cd ../../dist/linux-amd64-rocm/lib/ollama && tar xf - ) - -FROM --platform=linux/amd64 centos:7 AS cpu-builder-amd64 -ARG CMAKE_VERSION +### To create a local image for building linux binaries on mac or windows with efficient incremental builds +# +# docker build --platform linux/amd64 -t builder-amd64 -f Dockerfile --target unified-builder-amd64 . +# docker run --platform linux/amd64 --rm -it -v $(pwd):/go/src/github.com/ollama/ollama/ builder-amd64 +# +### Then incremental builds will be much faster in this container +# +# make -j 10 dist +# +FROM --platform=linux/amd64 rocm/dev-centos-7:${ROCM_VERSION}-complete AS unified-builder-amd64 ARG GOLANG_VERSION +ARG CUDA_VERSION_11 +ARG CUDA_VERSION_12 COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh -ENV PATH=/opt/rh/devtoolset-10/root/usr/bin:$PATH -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -ARG OLLAMA_CUSTOM_CPU_DEFS -ARG CGO_CFLAGS -ENV GOARCH=amd64 -WORKDIR /go/src/github.com/ollama/ollama/llm/generate +ENV PATH /opt/rh/devtoolset-10/root/usr/bin:/usr/local/cuda/bin:$PATH +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64 +RUN GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh +RUN yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo && \ + dnf clean all && \ + dnf install -y \ + zsh \ + cuda-toolkit-$(echo ${CUDA_VERSION_11} | cut -f1-2 -d. | sed -e "s/\./-/g") \ + cuda-toolkit-$(echo ${CUDA_VERSION_12} | cut -f1-2 -d. | sed -e "s/\./-/g") +# TODO intel oneapi goes here... 
+ENV GOARCH amd64 +ENV CGO_ENABLED 1 +WORKDIR /go/src/github.com/ollama/ollama/ +ENTRYPOINT [ "zsh" ] -FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_CPU_TARGET="static" bash gen_linux.sh -FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" bash gen_linux.sh -FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx" bash gen_linux.sh -FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu_avx2" bash gen_linux.sh - -FROM --platform=linux/arm64 rockylinux:8 AS cpu-builder-arm64 -ARG CMAKE_VERSION +### To create a local image for building linux binaries on mac or linux/arm64 with efficient incremental builds +# Note: this does not contain jetson variants +# +# docker build --platform linux/arm64 -t builder-arm64 -f Dockerfile --target unified-builder-arm64 . +# docker run --platform linux/arm64 --rm -it -v $(pwd):/go/src/github.com/ollama/ollama/ builder-arm64 +# +FROM --platform=linux/arm64 rockylinux:8 AS unified-builder-arm64 ARG GOLANG_VERSION +ARG CUDA_VERSION_11 +ARG CUDA_VERSION_12 COPY ./scripts/rh_linux_deps.sh / -RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh -ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH -COPY --from=llm-code / /go/src/github.com/ollama/ollama/ -ARG OLLAMA_CUSTOM_CPU_DEFS -ARG CGO_CFLAGS -ENV GOARCH=arm64 -WORKDIR /go/src/github.com/ollama/ollama/llm/generate +RUN GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh +RUN yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo && \ + dnf config-manager --set-enabled appstream && \ + dnf clean all && \ + dnf install -y \ + zsh \ + cuda-toolkit-$(echo ${CUDA_VERSION_11} | cut -f1-2 -d. | sed -e "s/\./-/g") \ + cuda-toolkit-$(echo ${CUDA_VERSION_12} | cut -f1-2 -d. | sed -e "s/\./-/g") +ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH:/usr/local/cuda/bin +ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64 +ENV LIBRARY_PATH=/usr/local/cuda/lib64/stubs:/opt/amdgpu/lib64 +ENV GOARCH arm64 +ENV CGO_ENABLED 1 +WORKDIR /go/src/github.com/ollama/ollama/ +ENTRYPOINT [ "zsh" ] -FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_CPU_TARGET="static" bash gen_linux.sh -FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64 -RUN --mount=type=cache,target=/root/.ccache \ - OLLAMA_SKIP_STATIC_GENERATE=1 OLLAMA_CPU_TARGET="cpu" bash gen_linux.sh - - -# Intermediate stages used for ./scripts/build_linux.sh -FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64 -ENV CGO_ENABLED=1 -WORKDIR /go/src/github.com/ollama/ollama +FROM --platform=linux/amd64 unified-builder-amd64 AS build-amd64 COPY . . 
-COPY --from=static-build-amd64 /go/src/github.com/ollama/ollama/llm/build/ llm/build/ -COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/build/ build/ -COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/build/ build/ -COPY --from=cuda-11-build-amd64 /go/src/github.com/ollama/ollama/dist/ dist/ -COPY --from=cuda-11-build-amd64 /go/src/github.com/ollama/ollama/build/ build/ -COPY --from=cuda-12-build-amd64 /go/src/github.com/ollama/ollama/dist/ dist/ -COPY --from=cuda-12-build-amd64 /go/src/github.com/ollama/ollama/build/ build/ -COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/dist/ dist/ -COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/build/ build/ -ARG GOFLAGS -ARG CGO_CFLAGS +ARG OLLAMA_SKIP_CUDA_GENERATE +ARG OLLAMA_SKIP_ROCM_GENERATE +ARG OLLAMA_FAST_BUILD +ARG VERSION +ARG CUSTOM_CPU_FLAGS RUN --mount=type=cache,target=/root/.ccache \ - go build -trimpath -o dist/linux-amd64/bin/ollama . + if grep "^flags" /proc/cpuinfo|grep avx>/dev/null; then \ + make -j $(nproc) dist ; \ + else \ + make -j 5 dist ; \ + fi RUN cd dist/linux-$GOARCH && \ - tar --exclude runners -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz -RUN cd dist/linux-$GOARCH-rocm && \ - tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-rocm.tgz + tar -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz +RUN if [ -z ${OLLAMA_SKIP_ROCM_GENERATE} ] ; then \ + cd dist/linux-$GOARCH-rocm && \ + tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-rocm.tgz ;\ + fi -FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64 -ENV CGO_ENABLED=1 +# Jetsons need to be built in discrete stages +FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK_5} AS runners-jetpack5-arm64 ARG GOLANG_VERSION -WORKDIR /go/src/github.com/ollama/ollama +RUN apt-get update && apt-get install -y git curl ccache && \ + curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-arm64.tar.gz | tar xz -C /usr/local && \ + ln -s /usr/local/go/bin/go /usr/local/bin/go && \ + ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt && \ + apt-get clean && rm -rf /var/lib/apt/lists/* +WORKDIR /go/src/github.com/ollama/ollama/ COPY . . -COPY --from=static-build-arm64 /go/src/github.com/ollama/ollama/llm/build/ llm/build/ -COPY --from=cuda-11-build-runner-arm64 /go/src/github.com/ollama/ollama/dist/ dist/ -COPY --from=cuda-11-build-runner-arm64 /go/src/github.com/ollama/ollama/build/ build/ -COPY --from=cuda-12-build-runner-arm64 /go/src/github.com/ollama/ollama/dist/ dist/ -COPY --from=cuda-12-build-runner-arm64 /go/src/github.com/ollama/ollama/build/ build/ -ARG GOFLAGS ARG CGO_CFLAGS +ENV GOARCH arm64 +ARG VERSION RUN --mount=type=cache,target=/root/.ccache \ - go build -trimpath -o dist/linux-arm64/bin/ollama . + make -j 5 dist_cuda_v11 \ + CUDA_ARCHITECTURES="72;87" \ + GPU_RUNNER_VARIANT=_jetpack5 \ + DIST_LIB_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack5/lib/ollama \ + DIST_GPU_RUNNER_DEPS_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack5/lib/ollama/cuda_jetpack5 + +FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK_6} AS runners-jetpack6-arm64 +ARG GOLANG_VERSION +RUN apt-get update && apt-get install -y git curl ccache && \ + curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-arm64.tar.gz | tar xz -C /usr/local && \ + ln -s /usr/local/go/bin/go /usr/local/bin/go && \ + ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt && \ + apt-get clean && rm -rf /var/lib/apt/lists/* +WORKDIR /go/src/github.com/ollama/ollama/ +COPY . . 
+ARG CGO_CFLAGS +ENV GOARCH arm64 +ARG VERSION +RUN --mount=type=cache,target=/root/.ccache \ + make -j 5 dist_cuda_v12 \ + CUDA_ARCHITECTURES="87" \ + GPU_RUNNER_VARIANT=_jetpack6 \ + DIST_LIB_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack6/lib/ollama \ + DIST_GPU_RUNNER_DEPS_DIR=/go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack6/lib/ollama/cuda_jetpack6 + +FROM --platform=linux/arm64 unified-builder-arm64 AS build-arm64 +COPY . . +ARG OLLAMA_SKIP_CUDA_GENERATE +ARG OLLAMA_FAST_BUILD +ARG VERSION +RUN --mount=type=cache,target=/root/.ccache \ + make -j 5 dist +COPY --from=runners-jetpack5-arm64 /go/src/github.com/ollama/ollama/dist/ dist/ +COPY --from=runners-jetpack6-arm64 /go/src/github.com/ollama/ollama/dist/ dist/ RUN cd dist/linux-$GOARCH && \ - tar --exclude runners -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz + tar -cf - . | pigz --best > ../ollama-linux-$GOARCH.tgz +RUN cd dist/linux-$GOARCH-jetpack5 && \ + tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-jetpack5.tgz +RUN cd dist/linux-$GOARCH-jetpack6 && \ + tar -cf - . | pigz --best > ../ollama-linux-$GOARCH-jetpack6.tgz FROM --platform=linux/amd64 scratch AS dist-amd64 COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/ollama-linux-*.tgz / FROM --platform=linux/arm64 scratch AS dist-arm64 COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/ollama-linux-*.tgz / -FROM dist-$TARGETARCH as dist +FROM dist-$TARGETARCH AS dist -# Optimized container images do not cary nested payloads -FROM --platform=linux/amd64 static-build-amd64 AS container-build-amd64 -WORKDIR /go/src/github.com/ollama/ollama -COPY . . -ARG GOFLAGS -ARG CGO_CFLAGS -RUN --mount=type=cache,target=/root/.ccache \ - go build -trimpath -o dist/linux-amd64/bin/ollama . +# For amd64 container images, filter out cuda/rocm to minimize size +FROM build-amd64 AS runners-cuda-amd64 +RUN rm -rf \ + ./dist/linux-amd64/lib/ollama/libggml_hipblas.so \ + ./dist/linux-amd64/lib/ollama/runners/rocm* -FROM --platform=linux/arm64 static-build-arm64 AS container-build-arm64 -WORKDIR /go/src/github.com/ollama/ollama -COPY . . -ARG GOFLAGS -ARG CGO_CFLAGS -RUN --mount=type=cache,target=/root/.ccache \ - go build -trimpath -o dist/linux-arm64/bin/ollama . 
+FROM build-amd64 AS runners-rocm-amd64 +RUN rm -rf \ + ./dist/linux-amd64/lib/ollama/libggml_cuda*.so \ + ./dist/linux-amd64/lib/ollama/libcu*.so* \ + ./dist/linux-amd64/lib/ollama/runners/cuda* FROM --platform=linux/amd64 ubuntu:22.04 AS runtime-amd64 RUN apt-get update && \ apt-get install -y ca-certificates && \ apt-get clean && rm -rf /var/lib/apt/lists/* -COPY --from=container-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/ -COPY --from=cpu-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=cuda-11-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=cuda-12-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ +COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/ +COPY --from=runners-cuda-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ FROM --platform=linux/arm64 ubuntu:22.04 AS runtime-arm64 RUN apt-get update && \ apt-get install -y ca-certificates && \ apt-get clean && rm -rf /var/lib/apt/lists/* -COPY --from=container-build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/bin/ /bin/ -COPY --from=cpu-build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/lib/ /lib/ -COPY --from=cuda-11-build-runner-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/lib/ /lib/ -COPY --from=cuda-12-build-runner-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/lib/ /lib/ +COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/bin/ /bin/ +COPY --from=build-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64/lib/ /lib/ +COPY --from=runners-jetpack5-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack5/lib/ /lib/ +COPY --from=runners-jetpack6-arm64 /go/src/github.com/ollama/ollama/dist/linux-arm64-jetpack6/lib/ /lib/ + # ROCm libraries larger so we keep it distinct from the CPU/CUDA image FROM --platform=linux/amd64 ubuntu:22.04 AS runtime-rocm # Frontload the rocm libraries which are large, and rarely change to increase chance of a common layer # across releases -COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64-rocm/lib/ /lib/ +COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64-rocm/lib/ /lib/ RUN apt-get update && \ apt-get install -y ca-certificates && \ apt-get clean && rm -rf /var/lib/apt/lists/* -COPY --from=container-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/ -COPY --from=cpu-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ -COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ +COPY --from=build-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/bin/ /bin/ +COPY --from=runners-rocm-amd64 /go/src/github.com/ollama/ollama/dist/linux-amd64/lib/ /lib/ + EXPOSE 11434 -ENV OLLAMA_HOST=0.0.0.0 +ENV OLLAMA_HOST 0.0.0.0 ENTRYPOINT ["/bin/ollama"] CMD ["serve"] FROM runtime-$TARGETARCH EXPOSE 11434 -ENV OLLAMA_HOST=0.0.0.0 +ENV OLLAMA_HOST 0.0.0.0 ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV 
LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..383354ee9 --- /dev/null +++ b/Makefile @@ -0,0 +1,103 @@ +# top level makefile for Ollama +include make/common-defs.make + + +# Determine which if any GPU runners we should build +include make/cuda-v11-defs.make +include make/cuda-v12-defs.make +include make/rocm-defs.make + +ifeq ($(CUSTOM_CPU_FLAGS),) +ifeq ($(ARCH),amd64) + RUNNER_TARGETS=cpu +endif +# Without CUSTOM_CPU_FLAGS we default to build both v11 and v12 if present +ifeq ($(OLLAMA_SKIP_CUDA_GENERATE),) +ifneq ($(CUDA_11_COMPILER),) + RUNNER_TARGETS += cuda_v11 +endif +ifneq ($(CUDA_12_COMPILER),) + RUNNER_TARGETS += cuda_v12 +endif +endif +else # CUSTOM_CPU_FLAGS is set, we'll build only the latest cuda version detected +ifneq ($(CUDA_12_COMPILER),) + RUNNER_TARGETS += cuda_v12 +else ifneq ($(CUDA_11_COMPILER),) + RUNNER_TARGETS += cuda_v11 +endif +endif + +ifeq ($(OLLAMA_SKIP_ROCM_GENERATE),) +ifneq ($(HIP_COMPILER),) + RUNNER_TARGETS += rocm +endif +endif + + +all: runners exe + +dist: $(addprefix dist_, $(RUNNER_TARGETS)) dist_exe + +dist_%: + @$(MAKE) --no-print-directory -f make/Makefile.$* dist + +runners: $(RUNNER_TARGETS) + +$(RUNNER_TARGETS): + @$(MAKE) --no-print-directory -f make/Makefile.$@ + +exe dist_exe: + @$(MAKE) --no-print-directory -f make/Makefile.ollama $@ + +help-sync apply-patches create-patches sync sync-clean: + @$(MAKE) --no-print-directory -f make/Makefile.sync $@ + +test integration lint: + @$(MAKE) --no-print-directory -f make/Makefile.test $@ + +clean: + rm -rf $(BUILD_DIR) $(DIST_LIB_DIR) $(OLLAMA_EXE) $(DIST_OLLAMA_EXE) + go clean -cache + +help: + @echo "The following make targets will help you build Ollama" + @echo "" + @echo " make all # (default target) Build Ollama llm subprocess runners, and the primary ollama executable" + @echo " make runners # Build Ollama llm subprocess runners; after you may use 'go build .' to build the primary ollama executable" + @echo " make <runner> # Build specific runners. Enabled: '$(RUNNER_TARGETS)'" + @echo " make dist # Build the runners and primary ollama executable for distribution" + @echo " make help-sync # Help information on vendor update targets" + @echo " make help-runners # Help information on runner targets" + @echo "" + @echo "The following make targets will help you test Ollama" + @echo "" + @echo " make test # Run unit tests" + @echo " make integration # Run integration tests. 
You must 'make all' first" + @echo " make lint # Run lint and style tests" + @echo "" + @echo "For more information see 'docs/development.md'" + @echo "" + + +help-runners: + @echo "The following runners will be built based on discovered GPU libraries: '$(RUNNER_TARGETS)'" + @echo "" + @echo "GPU Runner CPU Flags: '$(GPU_RUNNER_CPU_FLAGS)' (Override with CUSTOM_CPU_FLAGS)" + @echo "" + @echo "# CUDA_PATH sets the location where CUDA toolkits are present" + @echo "CUDA_PATH=$(CUDA_PATH)" + @echo " CUDA_11_PATH=$(CUDA_11_PATH)" + @echo " CUDA_11_COMPILER=$(CUDA_11_COMPILER)" + @echo " CUDA_12_PATH=$(CUDA_12_PATH)" + @echo " CUDA_12_COMPILER=$(CUDA_12_COMPILER)" + @echo "" + @echo "# HIP_PATH sets the location where the ROCm toolkit is present" + @echo "HIP_PATH=$(HIP_PATH)" + @echo " HIP_COMPILER=$(HIP_COMPILER)" + +.PHONY: all exe dist help help-sync help-runners test integration lint runners clean $(RUNNER_TARGETS) + +# Handy debugging for make variables +print-%: + @echo '$*=$($*)' diff --git a/README.md b/README.md index 466f315ad..908d31ebb 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,12 @@
ollama + ollama +
# Ollama -[![Discord](https://dcbadge.vercel.app/api/server/ollama?style=flat&compact=true)](https://discord.gg/ollama) +[Discord](https://discord.gg/ollama) Get up and running with large language models. @@ -12,7 +14,7 @@ Get up and running with large language models. [Download](https://ollama.com/download/Ollama-darwin.zip) -### Windows preview +### Windows [Download](https://ollama.com/download/OllamaSetup.exe) @@ -35,10 +37,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla ## Quickstart -To run and chat with [Llama 3.1](https://ollama.com/library/llama3.1): +To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2): ``` -ollama run llama3.1 +ollama run llama3.2 ``` ## Model library @@ -47,24 +49,28 @@ Ollama supports a list of models available on [ollama.com/library](https://ollam Here are some example models that can be downloaded: -| Model | Parameters | Size | Download | -| ------------------ | ---------- | ----- | ------------------------------ | -| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` | -| Llama 3.1 | 70B | 40GB | `ollama run llama3.1:70b` | -| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` | -| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` | -| Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` | -| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` | -| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` | -| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` | -| Mistral | 7B | 4.1GB | `ollama run mistral` | -| Moondream 2 | 1.4B | 829MB | `ollama run moondream` | -| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | -| Starling | 7B | 4.1GB | `ollama run starling-lm` | -| Code Llama | 7B | 3.8GB | `ollama run codellama` | -| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | -| LLaVA | 7B | 4.5GB | `ollama run llava` | -| Solar | 10.7B | 6.1GB | `ollama run solar` | +| Model | Parameters | Size | Download | +| ------------------ | ---------- | ----- | -------------------------------- | +| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` | +| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` | +| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` | +| Llama 3.2 Vision | 11B | 7.9GB | `ollama run llama3.2-vision` | +| Llama 3.2 Vision | 90B | 55GB | `ollama run llama3.2-vision:90b` | +| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` | +| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` | +| Phi 4 | 14B | 9.1GB | `ollama run phi4` | +| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` | +| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` | +| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` | +| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` | +| Mistral | 7B | 4.1GB | `ollama run mistral` | +| Moondream 2 | 1.4B | 829MB | `ollama run moondream` | +| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` | +| Starling | 7B | 4.1GB | `ollama run starling-lm` | +| Code Llama | 7B | 3.8GB | `ollama run codellama` | +| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | +| LLaVA | 7B | 4.5GB | `ollama run llava` | +| Solar | 10.7B | 6.1GB | `ollama run solar` | > [!NOTE] > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. @@ -93,22 +99,22 @@ Ollama supports importing GGUF models in the Modelfile: ollama run example ``` -### Import from PyTorch or Safetensors +### Import from Safetensors See the [guide](docs/import.md) on importing models for more information. 
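As a rough sketch of what that import looks like in practice (it assumes, per the import guide, that a Modelfile `FROM` line may point at a local Safetensors directory; the path below is hypothetical):

```
# point a Modelfile at a local Safetensors checkout, then create and run the model
echo "FROM /path/to/my-safetensors-model" > Modelfile
ollama create my-model -f Modelfile
ollama run my-model
```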
### Customize a prompt -Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.1` model: +Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.2` model: ``` -ollama pull llama3.1 +ollama pull llama3.2 ``` Create a `Modelfile`: ``` -FROM llama3.1 +FROM llama3.2 # set the temperature to 1 [higher is more creative, lower is more coherent] PARAMETER temperature 1 @@ -143,7 +149,7 @@ ollama create mymodel -f ./Modelfile ### Pull a model ``` -ollama pull llama3.1 +ollama pull llama3.2 ``` > This command can also be used to update a local model. Only the diff will be pulled. @@ -151,13 +157,13 @@ ollama pull llama3.1 ### Remove a model ``` -ollama rm llama3.1 +ollama rm llama3.2 ``` ### Copy a model ``` -ollama cp llama3.1 my-model +ollama cp llama3.2 my-model ``` ### Multiline input @@ -181,14 +187,14 @@ The image features a yellow smiley face, which is likely the central focus of th ### Pass the prompt as an argument ``` -$ ollama run llama3.1 "Summarize this file: $(cat README.md)" +$ ollama run llama3.2 "Summarize this file: $(cat README.md)" Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications. ``` ### Show model information ``` -ollama show llama3.1 +ollama show llama3.2 ``` ### List models on your computer @@ -206,7 +212,7 @@ ollama ps ### Stop a model which is currently running ``` -ollama stop llama3.1 +ollama stop llama3.2 ``` ### Start Ollama @@ -228,7 +234,7 @@ Next, start the server: Finally, in a separate shell, run a model: ``` -./ollama run llama3.1 +./ollama run llama3.2 ``` ## REST API @@ -239,7 +245,7 @@ Ollama has a REST API for running and managing models. ``` curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt":"Why is the sky blue?" }' ``` @@ -248,7 +254,7 @@ curl http://localhost:11434/api/generate -d '{ ``` curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] @@ -294,7 +300,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm) - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat) - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats) -- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Chat with Code Repository) +- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS) +- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories) - [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases) - [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG) - [RAGFlow](https://github.com/infiniflow/ragflow) (Open-source Retrieval-Augmented Generation engine based on deep document understanding) @@ -304,11 +311,17 @@ See the [API documentation](./docs/api.md) for all endpoints. 
- [Ollama RAG Chatbot](https://github.com/datvodinh/rag-chatbot.git) (Local Chat with multiple PDFs using Ollama and RAG) - [BrainSoup](https://www.nurgo-software.com/products/brainsoup) (Flexible native client with RAG & multi-agent automation) - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends) +- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama) +- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models) - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) +- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG) - [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord ) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) +- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine) +- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education) +- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application) - [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations) - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS) - [AI Studio](https://github.com/MindWorkAI/AI-Studio) @@ -316,6 +329,9 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows) - [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac) - [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend) +- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac) +- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita) +- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration) - [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang) - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery) - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j @@ -325,12 +341,44 @@ See the [API documentation](./docs/api.md) for all endpoints. 
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption) - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library) - [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama) +- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama) +- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface) +- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.) +- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux) +- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers +- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.) +- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page) +- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.) +- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) +- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar) +- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well) +- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface) +- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol) +- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app) +- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings) +- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder) +- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation) +- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI) +- [VT](https://github.com/vinhnx/vt.ai) (A minimal multimodal AI chat app, with dynamic conversation routing. 
Supports local models via Ollama) +- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama) +- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application available for Mac/Windows/Linux) +- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support) +- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow) +- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup) +- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI) + +### Cloud + +- [Google Cloud](https://cloud.google.com/run/docs/tutorials/gpu-gemma2-with-ollama) +- [Fly.io](https://fly.io/docs/python/do-more/add-ollama/) +- [Koyeb](https://www.koyeb.com/deploy/ollama) ### Terminal - [oterm](https://github.com/ggozad/oterm) - [Ellama Emacs client](https://github.com/s-kostyaev/ellama) - [Emacs client](https://github.com/zweifisch/ollama) +- [neollama](https://github.com/paradoxical-dev/neollama) UI client for interacting with models from within Neovim - [gen.nvim](https://github.com/David-Kunz/gen.nvim) - [ollama.nvim](https://github.com/nomnivore/ollama.nvim) - [ollero.nvim](https://github.com/marco-souza/ollero.nvim) @@ -340,7 +388,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Oatmeal](https://github.com/dustinblackman/oatmeal) - [cmdh](https://github.com/pgibler/cmdh) - [ooo](https://github.com/npahlfer/ooo) -- [shell-pilot](https://github.com/reid41/shell-pilot) +- [shell-pilot](https://github.com/reid41/shell-pilot)(Interact with models via pure shell scripts on Linux or macOS) - [tenere](https://github.com/pythops/tenere) - [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/). - [typechat-cli](https://github.com/anaisbetts/typechat-cli) @@ -348,17 +396,28 @@ See the [API documentation](./docs/api.md) for all endpoints. - [tlm](https://github.com/yusufcanb/tlm) - [podman-ollama](https://github.com/ericcurtin/podman-ollama) - [gollama](https://github.com/sammcj/gollama) +- [ParLlama](https://github.com/paulrobello/parllama) - [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/) - [Ollama Mixture of Experts (MOE) in 50 lines of code](https://github.com/rapidarchitect/ollama_moe) - [vim-intelligence-bridge](https://github.com/pepo-ec/vim-intelligence-bridge) Simple interaction of "Ollama" with the Vim editor +- [x-cmd ollama](https://x-cmd.com/mod/ollama) +- [bb7](https://github.com/drunkwcodes/bb7) +- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage) +- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more. +- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama +- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama. 
### Apple Vision Pro + - [Enchanted](https://github.com/AugustDev/enchanted) ### Database +- [pgai](https://github.com/timescale/pgai) - PostgreSQL as a vector database (Create and search embeddings from Ollama models using pgvector) + - [Get started guide](https://github.com/timescale/pgai/blob/main/docs/vectorizer-quick-start.md) - [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps) - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama) +- [Kangaroo](https://github.com/dbkangaroo/kangaroo) (AI-powered SQL client and admin tool for popular databases) ### Package managers @@ -371,13 +430,16 @@ See the [API documentation](./docs/api.md) for all endpoints. ### Libraries -- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa) +- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/) - [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama) - [crewAI](https://github.com/crewAIInc/crewAI) +- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration) +- [Spring AI](https://github.com/spring-projects/spring-ai) with [reference](https://docs.spring.io/spring-ai/reference/api/chat/ollama-chat.html) and [example](https://github.com/tzolov/ollama-tools) - [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example) - [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java) - [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs) -- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html) +- [LLPhant](https://github.com/theodo-group/LLPhant?tab=readme-ov-file#ollama) +- [LlamaIndex](https://docs.llamaindex.ai/en/stable/examples/llm/ollama/) and [LlamaIndexTS](https://ts.llamaindex.ai/modules/llms/available_llms/ollama) - [LiteLLM](https://github.com/BerriAI/litellm) - [OllamaFarm for Go](https://github.com/presbrey/ollamafarm) - [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp) @@ -401,16 +463,26 @@ See the [API documentation](./docs/api.md) for all endpoints. 
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama) - [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama) - [LlamaScript](https://github.com/Project-Llama/llamascript) +- [llm-axe](https://github.com/emirsahin1/llm-axe) (Python Toolkit for Building LLM Powered Apps) - [Gollm](https://docs.gollm.co/examples/ollama-example) +- [Gollama for Golang](https://github.com/jonathanhecl/gollama) - [Ollamaclient for Golang](https://github.com/xyproto/ollamaclient) - [High-level function abstraction in Go](https://gitlab.com/tozd/go/fun) - [Ollama PHP](https://github.com/ArdaGnsrn/ollama-php) - [Agents-Flex for Java](https://github.com/agents-flex/agents-flex) with [example](https://github.com/agents-flex/agents-flex/tree/main/agents-flex-llm/agents-flex-llm-ollama/src/test/java/com/agentsflex/llm/ollama) +- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama. +- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples) +- [Ollama for Swift](https://github.com/mattt/ollama-swift) +- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/) +- [GoLamify](https://github.com/prasad89/golamify) +- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell) +- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API) ### Mobile - [Enchanted](https://github.com/AugustDev/enchanted) - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid) +- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption) ### Extensions & Plugins @@ -418,6 +490,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Raycast extension](https://github.com/MassimilianoPasquini97/raycast_ollama) - [Discollama](https://github.com/mxyng/discollama) (Discord bot inside the Ollama discord channel) - [Continue](https://github.com/continuedev/continue) +- [Vibe](https://github.com/thewh1teagle/vibe) (Transcribe and analyze meetings with Ollama) - [Obsidian Ollama plugin](https://github.com/hinterdupfinger/obsidian-ollama) - [Logseq Ollama plugin](https://github.com/omagdy7/ollama-logseq) - [NotesOllama](https://github.com/andersrex/notesollama) (Apple Notes Ollama plugin) @@ -440,14 +513,26 @@ See the [API documentation](./docs/api.md) for all endpoints. - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend) - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support) - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation) +- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467) - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. 
Uses Ollama to create personalities. - [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server) -- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links) +- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.) +- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama) +- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.) +- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.) - [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality) - [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator) - [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator) +- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper) +- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama) +- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow) +- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language ### Supported backends - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. +### Observability + +- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics. +- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production. diff --git a/api/client.go b/api/client.go index 2528fb21f..4688d4d13 100644 --- a/api/client.go +++ b/api/client.go @@ -55,7 +55,7 @@ func checkError(resp *http.Response, body []byte) error { // ClientFromEnvironment creates a new [Client] using configuration from the // environment variable OLLAMA_HOST, which points to the network host and -// port on which the ollama service is listenting. The format of this variable +// port on which the ollama service is listening. The format of this variable // is: // // ://: diff --git a/api/types.go b/api/types.go index df7bab210..f4c5b1058 100644 --- a/api/types.go +++ b/api/types.go @@ -12,7 +12,7 @@ import ( "time" ) -// StatusError is an error with and HTTP status code. +// StatusError is an error with an HTTP status code and message. type StatusError struct { StatusCode int Status string @@ -57,7 +57,7 @@ type GenerateRequest struct { Template string `json:"template"` // Context is the context parameter returned from a previous call to - // Generate call. It can be used to keep a short conversational memory. + // [Client.Generate]. It can be used to keep a short conversational memory. Context []int `json:"context,omitempty"` // Stream specifies whether the response is streaming; it is true by default. 
@@ -67,7 +67,7 @@ type GenerateRequest struct { Raw bool `json:"raw,omitempty"` // Format specifies the format to return a response in. - Format string `json:"format"` + Format json.RawMessage `json:"format,omitempty"` // KeepAlive controls how long the model will stay loaded in memory following // this request. @@ -90,14 +90,14 @@ type ChatRequest struct { // Messages is the messages of the chat - can be used to keep a chat memory. Messages []Message `json:"messages"` - // Stream enable streaming of returned response; true by default. + // Stream enables streaming of returned responses; true by default. Stream *bool `json:"stream,omitempty"` // Format is the format to return the response in (e.g. "json"). - Format string `json:"format"` + Format json.RawMessage `json:"format,omitempty"` // KeepAlive controls how long the model will stay loaded into memory - // followin the request. + // following the request. KeepAlive *Duration `json:"keep_alive,omitempty"` // Tools is an optional list of tools the model has access to. @@ -146,6 +146,7 @@ type ToolCall struct { } type ToolCallFunction struct { + Index int `json:"index,omitempty"` Name string `json:"name"` Arguments ToolCallFunctionArguments `json:"arguments"` } @@ -203,8 +204,8 @@ type Metrics struct { EvalDuration time.Duration `json:"eval_duration,omitempty"` } -// Options specified in [GenerateRequest], if you add a new option here add it -// to the API docs also. +// Options specified in [GenerateRequest]. If you add a new option here, also +// add it to the API docs. type Options struct { Runner @@ -215,7 +216,6 @@ type Options struct { TopK int `json:"top_k,omitempty"` TopP float32 `json:"top_p,omitempty"` MinP float32 `json:"min_p,omitempty"` - TFSZ float32 `json:"tfs_z,omitempty"` TypicalP float32 `json:"typical_p,omitempty"` RepeatLastN int `json:"repeat_last_n,omitempty"` Temperature float32 `json:"temperature,omitempty"` @@ -225,7 +225,6 @@ type Options struct { Mirostat int `json:"mirostat,omitempty"` MirostatTau float32 `json:"mirostat_tau,omitempty"` MirostatEta float32 `json:"mirostat_eta,omitempty"` - PenalizeNewline bool `json:"penalize_newline,omitempty"` Stop []string `json:"stop,omitempty"` } @@ -236,7 +235,7 @@ type Runner struct { NumGPU int `json:"num_gpu,omitempty"` MainGPU int `json:"main_gpu,omitempty"` LowVRAM bool `json:"low_vram,omitempty"` - F16KV bool `json:"f16_kv,omitempty"` + F16KV bool `json:"f16_kv,omitempty"` // Deprecated: This option is ignored LogitsAll bool `json:"logits_all,omitempty"` VocabOnly bool `json:"vocab_only,omitempty"` UseMMap *bool `json:"use_mmap,omitempty"` @@ -295,17 +294,21 @@ type EmbeddingResponse struct { // CreateRequest is the request passed to [Client.Create]. 
type CreateRequest struct { - Model string `json:"model"` - Modelfile string `json:"modelfile"` - Stream *bool `json:"stream,omitempty"` - Quantize string `json:"quantize,omitempty"` + Model string `json:"model"` + Stream *bool `json:"stream,omitempty"` + Quantize string `json:"quantize,omitempty"` + + From string `json:"from,omitempty"` + Files map[string]string `json:"files,omitempty"` + Adapters map[string]string `json:"adapters,omitempty"` + Template string `json:"template,omitempty"` + License any `json:"license,omitempty"` + System string `json:"system,omitempty"` + Parameters map[string]any `json:"parameters,omitempty"` + Messages []Message `json:"messages,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` - - // Deprecated: set the file content with Modelfile instead - Path string `json:"path"` - // Deprecated: use Quantize instead Quantization string `json:"quantization,omitempty"` } @@ -594,7 +597,6 @@ func DefaultOptions() Options { Temperature: 0.8, TopK: 40, TopP: 0.9, - TFSZ: 1.0, TypicalP: 1.0, RepeatLastN: 64, RepeatPenalty: 1.1, @@ -603,7 +605,6 @@ func DefaultOptions() Options { Mirostat: 0, MirostatTau: 5.0, MirostatEta: 0.1, - PenalizeNewline: true, Seed: -1, Runner: Runner{ @@ -613,7 +614,6 @@ func DefaultOptions() Options { NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically NumThread: 0, // let the runtime decide LowVRAM: false, - F16KV: true, UseMLock: false, UseMMap: nil, }, diff --git a/app/lifecycle/lifecycle.go b/app/lifecycle/lifecycle.go index ab624e817..c24fe6462 100644 --- a/app/lifecycle/lifecycle.go +++ b/app/lifecycle/lifecycle.go @@ -11,10 +11,12 @@ import ( "github.com/ollama/ollama/app/store" "github.com/ollama/ollama/app/tray" + "github.com/ollama/ollama/envconfig" ) func Run() { InitLogging() + slog.Info("app config", "env", envconfig.Values()) ctx, cancel := context.WithCancel(context.Background()) var done chan int diff --git a/app/lifecycle/paths.go b/app/lifecycle/paths.go index 4d9f4c5a1..42ae8267a 100644 --- a/app/lifecycle/paths.go +++ b/app/lifecycle/paths.go @@ -36,8 +36,13 @@ func init() { ServerLogFile = filepath.Join(AppDataDir, "server.log") UpgradeLogFile = filepath.Join(AppDataDir, "upgrade.log") - // Executables are stored in APPDATA - AppDir = filepath.Join(localAppData, "Programs", "Ollama") + exe, err := os.Executable() + if err != nil { + slog.Warn("error discovering executable directory", "error", err) + AppDir = filepath.Join(localAppData, "Programs", "Ollama") + } else { + AppDir = filepath.Dir(exe) + } // Make sure we have PATH set correctly for any spawned children paths := strings.Split(os.Getenv("PATH"), ";") @@ -64,7 +69,7 @@ func init() { } // Make sure our logging dir exists - _, err := os.Stat(AppDataDir) + _, err = os.Stat(AppDataDir) if errors.Is(err, os.ErrNotExist) { if err := os.MkdirAll(AppDataDir, 0o755); err != nil { slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err)) diff --git a/app/lifecycle/server.go b/app/lifecycle/server.go index 37957399c..f7aa20264 100644 --- a/app/lifecycle/server.go +++ b/app/lifecycle/server.go @@ -18,11 +18,17 @@ func getCLIFullPath(command string) string { var cmdPath string appExe, err := os.Executable() if err == nil { + // Check both the same location as the tray app, as well as ./bin cmdPath = filepath.Join(filepath.Dir(appExe), command) _, err := os.Stat(cmdPath) if err == nil { return cmdPath } + cmdPath = filepath.Join(filepath.Dir(appExe), "bin", command) + _, err = os.Stat(cmdPath) + if err == nil { 
+ return cmdPath + } } cmdPath, err = exec.LookPath(command) if err == nil { diff --git a/app/lifecycle/updater_windows.go b/app/lifecycle/updater_windows.go index 1d3830d4e..293dd6038 100644 --- a/app/lifecycle/updater_windows.go +++ b/app/lifecycle/updater_windows.go @@ -26,19 +26,15 @@ func DoUpgrade(cancel context.CancelFunc, done chan int) error { slog.Info("starting upgrade with " + installerExe) slog.Info("upgrade log file " + UpgradeLogFile) - // When running in debug mode, we'll be "verbose" and let the installer pop up and prompt + // make the upgrade show progress, but non interactive installArgs := []string{ "/CLOSEAPPLICATIONS", // Quit the tray app if it's still running "/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd "/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed - } - // make the upgrade as quiet as possible (no GUI, no prompts) - installArgs = append(installArgs, - "/SP", // Skip the "This will install... Do you wish to continue" prompt - "/SUPPRESSMSGBOXES", + "/SP", // Skip the "This will install... Do you wish to continue" prompt + "/NOCANCEL", // Disable the ability to cancel upgrade mid-flight to avoid partially installed upgrades "/SILENT", - "/VERYSILENT", - ) + } // Safeguard in case we have requests in flight that need to drain... slog.Info("Waiting for server to shutdown") diff --git a/app/ollama.iss b/app/ollama.iss index 63b5bdb0f..d575fc7f6 100644 --- a/app/ollama.iss +++ b/app/ollama.iss @@ -53,8 +53,8 @@ RestartIfNeededByRun=no ; https://jrsoftware.org/ishelp/index.php?topic=setup_wizardimagefile WizardSmallImageFile=.\assets\setup.bmp -; TODO verifty actual min windows version... -; OG Win 10 +; Ollama requires Windows 10 22H2 or newer for proper unicode rendering +; TODO: consider setting this to 10.0.19045 MinVersion=10.0.10240 ; First release that supports WinRT UI Composition for win32 apps @@ -97,7 +97,6 @@ Source: "..\dist\windows-amd64\lib\ollama\*"; DestDir: "{app}\lib\ollama\"; Chec Source: "..\dist\windows-arm64\vc_redist.arm64.exe"; DestDir: "{tmp}"; Check: IsArm64() and vc_redist_needed(); Flags: deleteafterinstall Source: "..\dist\windows-arm64-app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ;Check: IsArm64(); Flags: ignoreversion 64bit Source: "..\dist\windows-arm64\ollama.exe"; DestDir: "{app}"; Check: IsArm64(); Flags: ignoreversion 64bit -Source: "..\dist\windows-arm64\lib\ollama\*"; DestDir: "{app}\lib\ollama\"; Check: IsArm64(); Flags: ignoreversion 64bit recursesubdirs #endif Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion @@ -136,13 +135,13 @@ Type: filesandordirs; Name: "{%TEMP}\ollama*" Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama" [Messages] -WizardReady=Ollama Windows Preview +WizardReady=Ollama ReadyLabel1=%nLet's get you up and running with your own large language models. SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or finish the other installer, then click OK to continue with this install, or Cancel to exit. ;FinishedHeadingLabel=Run your first model -;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3.1 +;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3.2 ;ClickFinish=%n [Registry] diff --git a/app/ollama_welcome.ps1 b/app/ollama_welcome.ps1 index 46777a3a6..e96957486 100644 --- a/app/ollama_welcome.ps1 +++ b/app/ollama_welcome.ps1 @@ -4,5 +4,5 @@ write-host "Welcome to Ollama!" 
write-host "" write-host "Run your first model:" write-host "" -write-host "`tollama run llama3.1" +write-host "`tollama run llama3.2" write-host "" \ No newline at end of file diff --git a/app/store/store.go b/app/store/store.go index b743e8a88..370436c58 100644 --- a/app/store/store.go +++ b/app/store/store.go @@ -64,7 +64,7 @@ func initStore() { slog.Debug(fmt.Sprintf("unexpected error searching for store: %s", err)) } slog.Debug("initializing new store") - store.ID = uuid.New().String() + store.ID = uuid.NewString() writeStore(getStorePath()) } diff --git a/app/tray/wintray/eventloop.go b/app/tray/wintray/eventloop.go index 157828a36..35608a49e 100644 --- a/app/tray/wintray/eventloop.go +++ b/app/tray/wintray/eventloop.go @@ -98,7 +98,7 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui } err = t.wcex.unregister() if err != nil { - slog.Error(fmt.Sprintf("failed to uregister windo %s", err)) + slog.Error(fmt.Sprintf("failed to unregister window %s", err)) } case WM_DESTROY: // same as WM_ENDSESSION, but throws 0 exit code after all diff --git a/app/tray/wintray/menus.go b/app/tray/wintray/menus.go index 596244442..0b13d7cb6 100644 --- a/app/tray/wintray/menus.go +++ b/app/tray/wintray/menus.go @@ -11,12 +11,13 @@ import ( ) const ( - updateAvailableMenuID = 1 - updateMenuID = updateAvailableMenuID + 1 - separatorMenuID = updateMenuID + 1 - diagLogsMenuID = separatorMenuID + 1 - diagSeparatorMenuID = diagLogsMenuID + 1 - quitMenuID = diagSeparatorMenuID + 1 + _ = iota + updateAvailableMenuID + updateMenuID + separatorMenuID + diagLogsMenuID + diagSeparatorMenuID + quitMenuID ) func (t *winTray) initMenus() error { @@ -38,7 +39,7 @@ func (t *winTray) UpdateAvailable(ver string) error { if err := t.addOrUpdateMenuItem(updateAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil { return fmt.Errorf("unable to create menu entries %w", err) } - if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenutTitle, false); err != nil { + if err := t.addOrUpdateMenuItem(updateMenuID, 0, updateMenuTitle, false); err != nil { return fmt.Errorf("unable to create menu entries %w", err) } if err := t.addSeparatorMenuItem(separatorMenuID, 0); err != nil { diff --git a/app/tray/wintray/messages.go b/app/tray/wintray/messages.go index d364c7163..64a47855f 100644 --- a/app/tray/wintray/messages.go +++ b/app/tray/wintray/messages.go @@ -10,6 +10,6 @@ const ( quitMenuTitle = "Quit Ollama" updateAvailableMenuTitle = "An update is available" - updateMenutTitle = "Restart to update" + updateMenuTitle = "Restart to update" diagLogsMenuTitle = "View logs" ) diff --git a/app/tray/wintray/tray.go b/app/tray/wintray/tray.go index 6f8278939..19fa98e9f 100644 --- a/app/tray/wintray/tray.go +++ b/app/tray/wintray/tray.go @@ -361,7 +361,7 @@ func (t *winTray) showMenu() error { boolRet, _, err = pTrackPopupMenu.Call( uintptr(t.menus[0]), - TPM_BOTTOMALIGN|TPM_LEFTALIGN, + TPM_BOTTOMALIGN|TPM_LEFTALIGN|TPM_RIGHTBUTTON, uintptr(p.X), uintptr(p.Y), 0, diff --git a/app/tray/wintray/w32api.go b/app/tray/wintray/w32api.go index 7c7c0ac8a..d23bfd977 100644 --- a/app/tray/wintray/w32api.go +++ b/app/tray/wintray/w32api.go @@ -67,6 +67,7 @@ const ( SW_HIDE = 0 TPM_BOTTOMALIGN = 0x0020 TPM_LEFTALIGN = 0x0000 + TPM_RIGHTBUTTON = 0x0002 WM_CLOSE = 0x0010 WM_USER = 0x0400 WS_CAPTION = 0x00C00000 diff --git a/build/darwin/amd64/placeholder b/build/darwin/amd64/placeholder deleted file mode 100644 index 87dc27381..000000000 --- a/build/darwin/amd64/placeholder +++ /dev/null @@ -1 +0,0 @@ -This 
is here to make sure the build/ directory exists for the go:embed command diff --git a/build/darwin/arm64/placeholder b/build/darwin/arm64/placeholder deleted file mode 100644 index 87dc27381..000000000 --- a/build/darwin/arm64/placeholder +++ /dev/null @@ -1 +0,0 @@ -This is here to make sure the build/ directory exists for the go:embed command diff --git a/build/embed_darwin_amd64.go b/build/embed_darwin_amd64.go deleted file mode 100644 index af1458ea9..000000000 --- a/build/embed_darwin_amd64.go +++ /dev/null @@ -1,8 +0,0 @@ -package build - -import "embed" - -// Darwin payloads separated by architecture to avoid duplicate payloads when cross compiling - -//go:embed darwin/amd64/* -var EmbedFS embed.FS diff --git a/build/embed_darwin_arm64.go b/build/embed_darwin_arm64.go deleted file mode 100644 index d885365d0..000000000 --- a/build/embed_darwin_arm64.go +++ /dev/null @@ -1,8 +0,0 @@ -package build - -import "embed" - -// Darwin payloads separated by architecture to avoid duplicate payloads when cross compiling - -//go:embed darwin/arm64/* -var EmbedFS embed.FS diff --git a/build/embed_linux.go b/build/embed_linux.go deleted file mode 100644 index 4cf7be4c3..000000000 --- a/build/embed_linux.go +++ /dev/null @@ -1,6 +0,0 @@ -package build - -import "embed" - -//go:embed linux/* -var EmbedFS embed.FS diff --git a/build/embed_unused.go b/build/embed_unused.go deleted file mode 100644 index 00fbe02e8..000000000 --- a/build/embed_unused.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build !linux && !darwin - -package build - -import "embed" - -// unused on windows -var EmbedFS embed.FS diff --git a/build/linux/amd64/placeholder b/build/linux/amd64/placeholder deleted file mode 100644 index 87dc27381..000000000 --- a/build/linux/amd64/placeholder +++ /dev/null @@ -1 +0,0 @@ -This is here to make sure the build/ directory exists for the go:embed command diff --git a/build/linux/arm64/placeholder b/build/linux/arm64/placeholder deleted file mode 100644 index 87dc27381..000000000 --- a/build/linux/arm64/placeholder +++ /dev/null @@ -1 +0,0 @@ -This is here to make sure the build/ directory exists for the go:embed command diff --git a/cmd/cmd.go b/cmd/cmd.go index 3bb8b06ec..cfefa35c6 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,13 +1,11 @@ package cmd import ( - "archive/zip" "bufio" - "bytes" "context" "crypto/ed25519" "crypto/rand" - "crypto/sha256" + "encoding/json" "encoding/pem" "errors" "fmt" @@ -19,9 +17,7 @@ import ( "os" "os/signal" "path/filepath" - "regexp" "runtime" - "slices" "strconv" "strings" "sync/atomic" @@ -36,100 +32,115 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/llama/runner" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/server" - "github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" ) -func CreateHandler(cmd *cobra.Command, args []string) error { +var errModelfileNotFound = errors.New("specified Modelfile wasn't found") + +func getModelfileName(cmd *cobra.Command) (string, error) { filename, _ := cmd.Flags().GetString("file") - filename, err := filepath.Abs(filename) + + if filename == "" { + filename = "Modelfile" + } + + absName, err := filepath.Abs(filename) + if err != nil { + return "", err + } + + _, err = os.Stat(absName) + if err != nil { + return filename, err + } + + 
return absName, nil +} + +func CreateHandler(cmd *cobra.Command, args []string) error { + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + var reader io.Reader + + filename, err := getModelfileName(cmd) + if os.IsNotExist(err) { + if filename == "" { + reader = strings.NewReader("FROM .\n") + } else { + return errModelfileNotFound + } + } else if err != nil { + return err + } else { + f, err := os.Open(filename) + if err != nil { + return err + } + + reader = f + defer f.Close() + } + + modelfile, err := parser.ParseFile(reader) if err != nil { return err } + status := "gathering model components" + spinner := progress.NewSpinner(status) + p.Add(status, spinner) + + req, err := modelfile.CreateRequest(filepath.Dir(filename)) + if err != nil { + return err + } + spinner.Stop() + + req.Name = args[0] + quantize, _ := cmd.Flags().GetString("quantize") + if quantize != "" { + req.Quantize = quantize + } + client, err := api.ClientFromEnvironment() if err != nil { return err } - p := progress.NewProgress(os.Stderr) - defer p.Stop() - - f, err := os.Open(filename) - if err != nil { - return err - } - defer f.Close() - - modelfile, err := parser.ParseFile(f) - if err != nil { - return err - } - - home, err := os.UserHomeDir() - if err != nil { - return err - } - - status := "transferring model data" - spinner := progress.NewSpinner(status) - p.Add(status, spinner) - defer p.Stop() - - for i := range modelfile.Commands { - switch modelfile.Commands[i].Name { - case "model", "adapter": - path := modelfile.Commands[i].Args - if path == "~" { - path = home - } else if strings.HasPrefix(path, "~/") { - path = filepath.Join(home, path[2:]) - } - - if !filepath.IsAbs(path) { - path = filepath.Join(filepath.Dir(filename), path) - } - - fi, err := os.Stat(path) - if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" { - continue - } else if err != nil { + if len(req.Files) > 0 { + fileMap := map[string]string{} + for f, digest := range req.Files { + if _, err := createBlob(cmd, client, f, digest, p); err != nil { return err } - - if fi.IsDir() { - // this is likely a safetensors or pytorch directory - // TODO make this work w/ adapters - tempfile, err := tempZipFiles(path) - if err != nil { - return err - } - defer os.RemoveAll(tempfile) - - path = tempfile - } - - digest, err := createBlob(cmd, client, path, spinner) - if err != nil { - return err - } - - modelfile.Commands[i].Args = "@" + digest + fileMap[filepath.Base(f)] = digest } + req.Files = fileMap + } + + if len(req.Adapters) > 0 { + fileMap := map[string]string{} + for f, digest := range req.Adapters { + if _, err := createBlob(cmd, client, f, digest, p); err != nil { + return err + } + fileMap[filepath.Base(f)] = digest + } + req.Adapters = fileMap } bars := make(map[string]*progress.Bar) fn := func(resp api.ProgressResponse) error { if resp.Digest != "" { - spinner.Stop() - bar, ok := bars[resp.Digest] if !ok { bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed) @@ -149,145 +160,23 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return nil } - quantize, _ := cmd.Flags().GetString("quantize") - - request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantize: quantize} - if err := client.Create(cmd.Context(), &request, fn); err != nil { + if err := client.Create(cmd.Context(), req, fn); err != nil { + if strings.Contains(err.Error(), "path or Modelfile are required") { + return fmt.Errorf("the ollama server must be updated to use 
`ollama create` with this client") + } return err } return nil } -func tempZipFiles(path string) (string, error) { - tempfile, err := os.CreateTemp("", "ollama-tf") +func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) { + realPath, err := filepath.EvalSymlinks(path) if err != nil { return "", err } - defer tempfile.Close() - detectContentType := func(path string) (string, error) { - f, err := os.Open(path) - if err != nil { - return "", err - } - defer f.Close() - - var b bytes.Buffer - b.Grow(512) - - if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { - return "", err - } - - contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") - return contentType, nil - } - - glob := func(pattern, contentType string) ([]string, error) { - matches, err := filepath.Glob(pattern) - if err != nil { - return nil, err - } - - for _, safetensor := range matches { - if ct, err := detectContentType(safetensor); err != nil { - return nil, err - } else if ct != contentType { - return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor) - } - } - - return matches, nil - } - - var files []string - if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 { - // safetensors files might be unresolved git lfs references; skip if they are - // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors - files = append(files, st...) - } else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 { - // covers adapters.safetensors - files = append(files, st...) - } else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 { - // covers adapter_model.safetensors - files = append(files, st...) - } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { - // pytorch files might also be unresolved git lfs references; skip if they are - // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin - files = append(files, pt...) - } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { - // pytorch files might also be unresolved git lfs references; skip if they are - // covers consolidated.x.pth, consolidated.pth - files = append(files, pt...) - } else { - return "", errors.New("no safetensors or torch files found") - } - - // add configuration files, json files are detected as text/plain - js, err := glob(filepath.Join(path, "*.json"), "text/plain") - if err != nil { - return "", err - } - files = append(files, js...) - - // bert models require a nested config.json - // TODO(mxyng): merge this with the glob above - js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") - if err != nil { - return "", err - } - files = append(files, js...) - - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { - // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob - // tokenizer.model might be a unresolved git lfs reference; error if it is - files = append(files, tks...) - } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { - // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) - files = append(files, tks...) 
- } - - zipfile := zip.NewWriter(tempfile) - defer zipfile.Close() - - for _, file := range files { - f, err := os.Open(file) - if err != nil { - return "", err - } - defer f.Close() - - fi, err := f.Stat() - if err != nil { - return "", err - } - - zfi, err := zip.FileInfoHeader(fi) - if err != nil { - return "", err - } - - zfi.Name, err = filepath.Rel(path, file) - if err != nil { - return "", err - } - - zf, err := zipfile.CreateHeader(zfi) - if err != nil { - return "", err - } - - if _, err := io.Copy(zf, f); err != nil { - return "", err - } - } - - return tempfile.Name(), nil -} - -func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) { - bin, err := os.Open(path) + bin, err := os.Open(realPath) if err != nil { return "", err } @@ -300,18 +189,11 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *pr } fileSize := fileInfo.Size() - hash := sha256.New() - if _, err := io.Copy(hash, bin); err != nil { - return "", err - } - - if _, err := bin.Seek(0, io.SeekStart); err != nil { - return "", err - } - var pw progressWriter - status := "transferring model data 0%" - spinner.SetMessage(status) + status := fmt.Sprintf("copying file %s 0%%", digest) + spinner := progress.NewSpinner(status) + p.Add(status, spinner) + defer spinner.Stop() done := make(chan struct{}) defer close(done) @@ -322,15 +204,14 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *pr for { select { case <-ticker.C: - spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize))) + spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize))) case <-done: - spinner.SetMessage("transferring model data 100%") + spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest)) return } } }() - digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { return "", err } @@ -422,6 +303,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { if len(prompts) > 0 { interactive = false } + // Be quiet if we're redirecting to a pipe or file + if !term.IsTerminal(int(os.Stdout.Fd())) { + interactive = false + } nowrap, err := cmd.Flags().GetBool("nowordwrap") if err != nil { @@ -453,7 +338,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - opts.MultiModal = slices.Contains(info.Details.Families, "clip") + opts.MultiModal = len(info.ProjectorInfo) != 0 opts.ParentModel = info.Details.ParentModel if interactive { @@ -478,47 +363,6 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } -func errFromUnknownKey(unknownKeyErr error) error { - // find SSH public key in the error message - sshKeyPattern := `ssh-\w+ [^\s"]+` - re := regexp.MustCompile(sshKeyPattern) - matches := re.FindStringSubmatch(unknownKeyErr.Error()) - - if len(matches) > 0 { - serverPubKey := matches[0] - - localPubKey, err := auth.GetPublicKey() - if err != nil { - return unknownKeyErr - } - - if runtime.GOOS == "linux" && serverPubKey != localPubKey { - // try the ollama service public key - svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") - if err != nil { - return unknownKeyErr - } - localPubKey = strings.TrimSpace(string(svcPubKey)) - } - - // check if the returned public key matches the local public key, this prevents adding a remote key to the user's account - if serverPubKey != localPubKey { - return 
unknownKeyErr - } - - var msg strings.Builder - msg.WriteString(unknownKeyErr.Error()) - msg.WriteString("\n\nYour ollama key is:\n") - msg.WriteString(localPubKey) - msg.WriteString("\nAdd your key at:\n") - msg.WriteString("https://ollama.com/settings/keys") - - return errors.New(msg.String()) - } - - return unknownKeyErr -} - func PushHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -565,6 +409,8 @@ func PushHandler(cmd *cobra.Command, args []string) error { } request := api.PushRequest{Name: args[0], Insecure: insecure} + + n := model.ParseName(args[0]) if err := client.Push(cmd.Context(), &request, fn); err != nil { if spinner != nil { spinner.Stop() @@ -572,18 +418,19 @@ func PushHandler(cmd *cobra.Command, args []string) error { if strings.Contains(err.Error(), "access denied") { return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") } - host := model.ParseName(args[0]).Host - isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com") - if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost { - // the user has not added their ollama key to ollama.com - // re-throw an error with a more user-friendly message - return errFromUnknownKey(err) - } - return err } + p.Stop() spinner.Stop() + + destination := n.String() + if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") { + destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest") + } + fmt.Printf("\nYou can find your model at:\n\n") + fmt.Printf("\t%s\n", destination) + return nil } @@ -601,7 +448,7 @@ func ListHandler(cmd *cobra.Command, args []string) error { var data [][]string for _, m := range models.Models { - if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) { + if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")}) } } @@ -680,6 +527,17 @@ func DeleteHandler(cmd *cobra.Command, args []string) error { return err } + // Unload the model if it's running before deletion + opts := &runOptions{ + Model: args[0], + KeepAlive: &api.Duration{Duration: 0}, + } + if err := loadOrUnloadModel(cmd, opts); err != nil { + if !strings.Contains(err.Error(), "not found") { + return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err) + } + } + for _, name := range args { req := api.DeleteRequest{Name: name} if err := client.Delete(cmd.Context(), &req); err != nil { @@ -755,9 +613,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error { case "parameters": fmt.Println(resp.Parameters) case "system": - fmt.Println(resp.System) + fmt.Print(resp.System) case "template": - fmt.Println(resp.Template) + fmt.Print(resp.Template) } return nil @@ -1027,10 +885,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { return nil } + if opts.Format == "json" { + opts.Format = `"` + opts.Format + `"` + } + req := &api.ChatRequest{ Model: opts.Model, Messages: opts.Messages, - Format: opts.Format, + Format: json.RawMessage(opts.Format), Options: opts.Options, } @@ -1112,12 +974,16 @@ func generate(cmd *cobra.Command, opts runOptions) error { } } + if opts.Format == "json" { + opts.Format = `"` + opts.Format + `"` + } + request := api.GenerateRequest{ Model: opts.Model, Prompt: opts.Prompt, Context: 
generateContext, Images: opts.Images, - Format: opts.Format, + Format: json.RawMessage(opts.Format), System: opts.System, Options: opts.Options, KeepAlive: opts.KeepAlive, @@ -1273,7 +1139,7 @@ func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile) cobra.EnableCommandSorting = false - if runtime.GOOS == "windows" { + if runtime.GOOS == "windows" && term.IsTerminal(int(os.Stdout.Fd())) { console.ConsoleFromFile(os.Stdin) //nolint:errcheck } @@ -1305,7 +1171,7 @@ func NewCLI() *cobra.Command { RunE: CreateHandler, } - createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile") + createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)") showCmd := &cobra.Command{ @@ -1403,6 +1269,19 @@ func NewCLI() *cobra.Command { RunE: DeleteHandler, } + runnerCmd := &cobra.Command{ + Use: "runner", + Short: llama.PrintSystemInfo(), + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + return runner.Execute(os.Args[1:]) + }, + FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true}, + } + runnerCmd.SetHelpFunc(func(cmd *cobra.Command, args []string) { + _ = runner.Execute(args[1:]) + }) + envVars := envconfig.AsMap() envs := []envconfig.EnvVar{envVars["OLLAMA_HOST"]} @@ -1437,6 +1316,7 @@ func NewCLI() *cobra.Command { envVars["OLLAMA_SCHED_SPREAD"], envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_FLASH_ATTENTION"], + envVars["OLLAMA_KV_CACHE_TYPE"], envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_GPU_OVERHEAD"], envVars["OLLAMA_LOAD_TIMEOUT"], @@ -1458,6 +1338,7 @@ func NewCLI() *cobra.Command { psCmd, copyCmd, deleteCmd, + runnerCmd, ) return rootCmd diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 0f8863cc7..069428bec 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -2,11 +2,17 @@ package cmd import ( "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" "os" - "path/filepath" + "strings" "testing" "github.com/google/go-cmp/cmp" + "github.com/spf13/cobra" "github.com/ollama/ollama/api" ) @@ -173,18 +179,14 @@ Weigh anchor! t.Run("license", func(t *testing.T) { var b bytes.Buffer - license, err := os.ReadFile(filepath.Join("..", "LICENSE")) - if err != nil { - t.Fatal(err) - } - + license := "MIT License\nCopyright (c) Ollama\n" if err := showInfo(&api.ShowResponse{ Details: api.ModelDetails{ Family: "test", ParameterSize: "7B", QuantizationLevel: "FP16", }, - License: string(license), + License: license, }, &b); err != nil { t.Fatal(err) } @@ -204,3 +206,413 @@ Weigh anchor! 
} }) } + +func TestDeleteHandler(t *testing.T) { + stopped := false + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/delete" && r.Method == http.MethodDelete { + var req api.DeleteRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if req.Name == "test-model" { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusNotFound) + } + return + } + if r.URL.Path == "/api/generate" && r.Method == http.MethodPost { + var req api.GenerateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if req.Model == "test-model" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(api.GenerateResponse{ + Done: true, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + stopped = true + return + } else { + w.WriteHeader(http.StatusNotFound) + if err := json.NewEncoder(w).Encode(api.GenerateResponse{ + Done: false, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } + } + })) + + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + + cmd := &cobra.Command{} + cmd.SetContext(context.TODO()) + if err := DeleteHandler(cmd, []string{"test-model"}); err != nil { + t.Fatalf("DeleteHandler failed: %v", err) + } + if !stopped { + t.Fatal("Model was not stopped before deletion") + } + + err := DeleteHandler(cmd, []string{"test-model-not-found"}) + if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") { + t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) + } +} + +func TestGetModelfileName(t *testing.T) { + tests := []struct { + name string + modelfileName string + fileExists bool + expectedName string + expectedErr error + }{ + { + name: "no modelfile specified, no modelfile exists", + modelfileName: "", + fileExists: false, + expectedName: "Modelfile", + expectedErr: os.ErrNotExist, + }, + { + name: "no modelfile specified, modelfile exists", + modelfileName: "", + fileExists: true, + expectedName: "Modelfile", + expectedErr: nil, + }, + { + name: "modelfile specified, no modelfile exists", + modelfileName: "crazyfile", + fileExists: false, + expectedName: "crazyfile", + expectedErr: os.ErrNotExist, + }, + { + name: "modelfile specified, modelfile exists", + modelfileName: "anotherfile", + fileExists: true, + expectedName: "anotherfile", + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := &cobra.Command{ + Use: "fakecmd", + } + cmd.Flags().String("file", "", "path to modelfile") + + var expectedFilename string + + if tt.fileExists { + tempDir, err := os.MkdirTemp("", "modelfiledir") + defer os.RemoveAll(tempDir) + if err != nil { + t.Fatalf("temp modelfile dir creation failed: %v", err) + } + var fn string + if tt.modelfileName != "" { + fn = tt.modelfileName + } else { + fn = "Modelfile" + } + + tempFile, err := os.CreateTemp(tempDir, fn) + if err != nil { + t.Fatalf("temp modelfile creation failed: %v", err) + } + + expectedFilename = tempFile.Name() + err = cmd.Flags().Set("file", expectedFilename) + if err != nil { + t.Fatalf("couldn't set file flag: %v", err) + } + } else { + expectedFilename = tt.expectedName + if tt.modelfileName != "" { + err := cmd.Flags().Set("file", tt.modelfileName) + if 
err != nil { + t.Fatalf("couldn't set file flag: %v", err) + } + } + } + + actualFilename, actualErr := getModelfileName(cmd) + + if actualFilename != expectedFilename { + t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename) + } + + if tt.expectedErr != os.ErrNotExist { + if actualErr != tt.expectedErr { + t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) + } + } else { + if !os.IsNotExist(actualErr) { + t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) + } + } + }) + } +} + +func TestPushHandler(t *testing.T) { + tests := []struct { + name string + modelName string + serverResponse map[string]func(w http.ResponseWriter, r *http.Request) + expectedError string + expectedOutput string + }{ + { + name: "successful push", + modelName: "test-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/push": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + + var req api.PushRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if req.Name != "test-model" { + t.Errorf("expected model name 'test-model', got %s", req.Name) + } + + // Simulate progress updates + responses := []api.ProgressResponse{ + {Status: "preparing manifest"}, + {Digest: "sha256:abc123456789", Total: 100, Completed: 50}, + {Digest: "sha256:abc123456789", Total: 100, Completed: 100}, + } + + for _, resp := range responses { + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.(http.Flusher).Flush() + } + }, + }, + expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", + }, + { + name: "unauthorized push", + modelName: "unauthorized-model", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/push": func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + err := json.NewEncoder(w).Encode(map[string]string{ + "error": "access denied", + }) + if err != nil { + t.Fatal(err) + } + }, + }, + expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if handler, ok := tt.serverResponse[r.URL.Path]; ok { + handler(w, r) + return + } + http.Error(w, "not found", http.StatusNotFound) + })) + defer mockServer.Close() + + t.Setenv("OLLAMA_HOST", mockServer.URL) + + cmd := &cobra.Command{} + cmd.Flags().Bool("insecure", false, "") + cmd.SetContext(context.TODO()) + + // Redirect stderr to capture progress output + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + // Capture stdout for the "Model pushed" message + oldStdout := os.Stdout + outR, outW, _ := os.Pipe() + os.Stdout = outW + + err := PushHandler(cmd, []string{tt.modelName}) + + // Restore stderr + w.Close() + os.Stderr = oldStderr + // drain the pipe + if _, err := io.ReadAll(r); err != nil { + t.Fatal(err) + } + + // Restore stdout and get output + outW.Close() + os.Stdout = oldStdout + stdout, _ := io.ReadAll(outR) + + if tt.expectedError == "" { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if 
tt.expectedOutput != "" { + if got := string(stdout); got != tt.expectedOutput { + t.Errorf("expected output %q, got %q", tt.expectedOutput, got) + } + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.expectedError) { + t.Errorf("expected error containing %q, got %v", tt.expectedError, err) + } + } + }) + } +} + +func TestCreateHandler(t *testing.T) { + tests := []struct { + name string + modelName string + modelFile string + serverResponse map[string]func(w http.ResponseWriter, r *http.Request) + expectedError string + expectedOutput string + }{ + { + name: "successful create", + modelName: "test-model", + modelFile: "FROM foo", + serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ + "/api/create": func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST request, got %s", r.Method) + } + + req := api.CreateRequest{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if req.Name != "test-model" { + t.Errorf("expected model name 'test-model', got %s", req.Name) + } + + if req.From != "foo" { + t.Errorf("expected from 'foo', got %s", req.From) + } + + responses := []api.ProgressResponse{ + {Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}, + {Status: "writing manifest"}, + {Status: "success"}, + } + + for _, resp := range responses { + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.(http.Flusher).Flush() + } + }, + }, + expectedOutput: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler, ok := tt.serverResponse[r.URL.Path] + if !ok { + t.Errorf("unexpected request to %s", r.URL.Path) + http.Error(w, "not found", http.StatusNotFound) + return + } + handler(w, r) + })) + t.Setenv("OLLAMA_HOST", mockServer.URL) + t.Cleanup(mockServer.Close) + tempFile, err := os.CreateTemp("", "modelfile") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + if _, err := tempFile.WriteString(tt.modelFile); err != nil { + t.Fatal(err) + } + if err := tempFile.Close(); err != nil { + t.Fatal(err) + } + + cmd := &cobra.Command{} + cmd.Flags().String("file", "", "") + if err := cmd.Flags().Set("file", tempFile.Name()); err != nil { + t.Fatal(err) + } + + cmd.Flags().Bool("insecure", false, "") + cmd.SetContext(context.TODO()) + + // Redirect stderr to capture progress output + oldStderr := os.Stderr + r, w, _ := os.Pipe() + os.Stderr = w + + // Capture stdout for the "Model pushed" message + oldStdout := os.Stdout + outR, outW, _ := os.Pipe() + os.Stdout = outW + + err = CreateHandler(cmd, []string{tt.modelName}) + + // Restore stderr + w.Close() + os.Stderr = oldStderr + // drain the pipe + if _, err := io.ReadAll(r); err != nil { + t.Fatal(err) + } + + // Restore stdout and get output + outW.Close() + os.Stdout = oldStdout + stdout, _ := io.ReadAll(outR) + + if tt.expectedError == "" { + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + if tt.expectedOutput != "" { + if got := string(stdout); got != tt.expectedOutput { + t.Errorf("expected output %q, got %q", tt.expectedOutput, got) + } + } + } + }) + } +} diff --git a/cmd/interactive.go b/cmd/interactive.go index 94578f11b..af58f4fc4 100644 --- a/cmd/interactive.go +++ 
b/cmd/interactive.go @@ -13,11 +13,9 @@ import ( "strings" "github.com/spf13/cobra" - "golang.org/x/exp/maps" "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/types/errtypes" ) @@ -213,10 +211,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { return err } - req := &api.CreateRequest{ - Name: args[1], - Modelfile: buildModelfile(opts), - } + req := NewCreateRequest(args[1], opts) fn := func(resp api.ProgressResponse) error { return nil } err = client.Create(cmd.Context(), req, fn) if err != nil { @@ -319,8 +314,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { opts.Messages = append(opts.Messages, newMessage) } fmt.Println("Set system message.") - sb.Reset() - sb.Reset() continue default: @@ -442,13 +435,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { return err } - // clear all previous images for better responses - if len(images) > 0 { - for i := range opts.Messages { - opts.Messages[i].Images = nil - } - } - newMessage.Content = msg newMessage.Images = images } @@ -468,68 +454,51 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } } -func buildModelfile(opts runOptions) string { - var f parser.File - f.Commands = append(f.Commands, parser.Command{Name: "model", Args: cmp.Or(opts.ParentModel, opts.Model)}) +func NewCreateRequest(name string, opts runOptions) *api.CreateRequest { + req := &api.CreateRequest{ + Name: name, + From: cmp.Or(opts.ParentModel, opts.Model), + } if opts.System != "" { - f.Commands = append(f.Commands, parser.Command{Name: "system", Args: opts.System}) + req.System = opts.System } - keys := maps.Keys(opts.Options) - slices.Sort(keys) - for _, k := range keys { - v := opts.Options[k] - var cmds []parser.Command - switch t := v.(type) { - case []string: - for _, s := range t { - cmds = append(cmds, parser.Command{Name: k, Args: s}) - } - default: - cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", t)}) - } - - f.Commands = append(f.Commands, cmds...) 
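
For reference, a sketch of how the NewCreateRequest helper introduced above can be used in place of the old Modelfile round trip. This is illustrative only: it assumes it sits in package cmd next to the code above, with a context import added and the api client already available; the progress callback is a no-op placeholder.

// Illustrative helper, not part of this patch: relies on runOptions and
// NewCreateRequest defined above.
func saveSession(ctx context.Context, client *api.Client, name string, opts runOptions) error {
	req := NewCreateRequest(name, opts) // From, System, Parameters and Messages come from opts
	noop := func(api.ProgressResponse) error { return nil }
	return client.Create(ctx, req, noop)
}
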
+ if len(opts.Options) > 0 { + req.Parameters = opts.Options } - for _, msg := range opts.Messages { - f.Commands = append(f.Commands, parser.Command{Name: "message", Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content)}) + if len(opts.Messages) > 0 { + req.Messages = opts.Messages } - return f.String() + return req } func normalizeFilePath(fp string) string { - // Define a map of escaped characters and their replacements - replacements := map[string]string{ - "\\ ": " ", // Escaped space - "\\(": "(", // Escaped left parenthesis - "\\)": ")", // Escaped right parenthesis - "\\[": "[", // Escaped left square bracket - "\\]": "]", // Escaped right square bracket - "\\{": "{", // Escaped left curly brace - "\\}": "}", // Escaped right curly brace - "\\$": "$", // Escaped dollar sign - "\\&": "&", // Escaped ampersand - "\\;": ";", // Escaped semicolon - "\\'": "'", // Escaped single quote - "\\\\": "\\", // Escaped backslash - "\\*": "*", // Escaped asterisk - "\\?": "?", // Escaped question mark - } - - for escaped, actual := range replacements { - fp = strings.ReplaceAll(fp, escaped, actual) - } - return fp + return strings.NewReplacer( + "\\ ", " ", // Escaped space + "\\(", "(", // Escaped left parenthesis + "\\)", ")", // Escaped right parenthesis + "\\[", "[", // Escaped left square bracket + "\\]", "]", // Escaped right square bracket + "\\{", "{", // Escaped left curly brace + "\\}", "}", // Escaped right curly brace + "\\$", "$", // Escaped dollar sign + "\\&", "&", // Escaped ampersand + "\\;", ";", // Escaped semicolon + "\\'", "'", // Escaped single quote + "\\\\", "\\", // Escaped backslash + "\\*", "*", // Escaped asterisk + "\\?", "?", // Escaped question mark + ).Replace(fp) } func extractFileNames(input string) []string { // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // and followed by more characters and a file extension // This will capture non filename strings, but we'll check for file existence to remove mismatches - regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b` + regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b` re := regexp.MustCompile(regexPattern) return re.FindAllString(input, -1) @@ -542,10 +511,9 @@ func extractFileData(input string) (string, []api.ImageData, error) { for _, fp := range filePaths { nfp := normalizeFilePath(fp) data, err := getImageData(nfp) - if err != nil { - if os.IsNotExist(err) { - continue - } + if errors.Is(err, os.ErrNotExist) { + continue + } else if err != nil { fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err) return "", imgs, err } @@ -553,7 +521,7 @@ func extractFileData(input string) (string, []api.ImageData, error) { input = strings.ReplaceAll(input, fp, "") imgs = append(imgs, data) } - return input, imgs, nil + return strings.TrimSpace(input), imgs, nil } func getImageData(filePath string) ([]byte, error) { diff --git a/cmd/interactive_test.go b/cmd/interactive_test.go index bb7e0abaf..3f60448b6 100644 --- a/cmd/interactive_test.go +++ b/cmd/interactive_test.go @@ -3,105 +3,50 @@ package cmd import ( "testing" - "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" - - "github.com/ollama/ollama/api" ) func TestExtractFilenames(t *testing.T) { // Unix style paths input := ` some preamble - ./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 -/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg` + ./relative\ 
path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg +/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG` res := extractFileNames(input) assert.Len(t, res, 5) assert.Contains(t, res[0], "one.png") assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[3], "four.png") - assert.Contains(t, res[4], "five.svg") + assert.Contains(t, res[4], "five.JPG") assert.NotContains(t, res[4], '"') - assert.NotContains(t, res, "inbtween") + assert.NotContains(t, res, "inbetween1") + assert.NotContains(t, res, "./1.svg") // Windows style paths input = ` some preamble c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2 /absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4 -./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6 -d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8 - d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending +./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6 +d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8 + d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending ` res = extractFileNames(input) assert.Len(t, res, 10) - assert.NotContains(t, res, "inbtween") + assert.NotContains(t, res, "inbetween2") assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "c:") assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "c:") assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[3], "four.png") - assert.Contains(t, res[4], "five.svg") + assert.Contains(t, res[4], "five.JPG") assert.Contains(t, res[5], "six.png") - assert.Contains(t, res[6], "seven.svg") + assert.Contains(t, res[6], "seven.JPEG") assert.Contains(t, res[6], "d:") assert.Contains(t, res[7], "eight.png") assert.Contains(t, res[7], "c:") assert.Contains(t, res[8], "nine.png") assert.Contains(t, res[8], "d:") - assert.Contains(t, res[9], "ten.svg") + assert.Contains(t, res[9], "ten.PNG") assert.Contains(t, res[9], "E:") } - -func TestModelfileBuilder(t *testing.T) { - opts := runOptions{ - Model: "hork", - System: "You are part horse and part shark, but all hork. Do horklike things", - Messages: []api.Message{ - {Role: "user", Content: "Hey there hork!"}, - {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, - }, - Options: map[string]any{ - "temperature": 0.9, - "seed": 42, - "penalize_newline": false, - "stop": []string{"hi", "there"}, - }, - } - - t.Run("model", func(t *testing.T) { - expect := `FROM hork -SYSTEM You are part horse and part shark, but all hork. Do horklike things -PARAMETER penalize_newline false -PARAMETER seed 42 -PARAMETER stop hi -PARAMETER stop there -PARAMETER temperature 0.9 -MESSAGE user Hey there hork! -MESSAGE assistant Yes it is true, I am half horse, half shark. -` - - actual := buildModelfile(opts) - if diff := cmp.Diff(expect, actual); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } - }) - - t.Run("parent model", func(t *testing.T) { - opts.ParentModel = "horseshark" - expect := `FROM horseshark -SYSTEM You are part horse and part shark, but all hork. Do horklike things -PARAMETER penalize_newline false -PARAMETER seed 42 -PARAMETER stop hi -PARAMETER stop there -PARAMETER temperature 0.9 -MESSAGE user Hey there hork! -MESSAGE assistant Yes it is true, I am half horse, half shark. 
-` - actual := buildModelfile(opts) - if diff := cmp.Diff(expect, actual); diff != "" { - t.Errorf("mismatch (-want +got):\n%s", diff) - } - }) -} diff --git a/cmd/runner/main.go b/cmd/runner/main.go new file mode 100644 index 000000000..34b0e9d21 --- /dev/null +++ b/cmd/runner/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + "os" + + "github.com/ollama/ollama/llama/runner" +) + +func main() { + if err := runner.Execute(os.Args[1:]); err != nil { + fmt.Fprintf(os.Stderr, "error: %s\n", err) + os.Exit(1) + } +} diff --git a/convert/convert_test.go b/convert/convert_test.go index 2969673d5..48a2b1d45 100644 --- a/convert/convert_test.go +++ b/convert/convert_test.go @@ -29,7 +29,7 @@ type tensorData struct { Shape []int `json:"shape"` } -func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) { +func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, *llm.Tensors) { t.Helper() f, err := os.CreateTemp(t.TempDir(), "f16") @@ -60,7 +60,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) { return r, m.KV(), m.Tensors() } -func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors llm.Tensors) map[string]string { +func generateResultsJSON(t *testing.T, f *os.File, kv llm.KV, tensors *llm.Tensors) map[string]string { actual := make(map[string]string) for k, v := range kv { if s, ok := v.(json.Marshaler); !ok { diff --git a/convert/sentencepiece/sentencepiece_model.pb.go b/convert/sentencepiece/sentencepiece_model.pb.go index 5c8db9bce..6bf668911 100644 --- a/convert/sentencepiece/sentencepiece_model.pb.go +++ b/convert/sentencepiece/sentencepiece_model.pb.go @@ -331,7 +331,7 @@ type TrainerSpec struct { // Reserved special meta tokens. // * -1 is not used. // * unk_id must not be -1. - // Id must starts with 0 and be contigous. + // Id must start with 0 and be contiguous. UnkId *int32 `protobuf:"varint,40,opt,name=unk_id,json=unkId,def=0" json:"unk_id,omitempty"` // BosId *int32 `protobuf:"varint,41,opt,name=bos_id,json=bosId,def=1" json:"bos_id,omitempty"` // EosId *int32 `protobuf:"varint,42,opt,name=eos_id,json=eosId,def=2" json:"eos_id,omitempty"` // diff --git a/convert/sentencepiece_model.proto b/convert/sentencepiece_model.proto index 5dc02d6cf..370887a4a 100644 --- a/convert/sentencepiece_model.proto +++ b/convert/sentencepiece_model.proto @@ -213,7 +213,7 @@ message TrainerSpec { // Reserved special meta tokens. // * -1 is not used. // * unk_id must not be -1. - // Id must starts with 0 and be contigous. + // Id must start with 0 and be contiguous. 
optional int32 unk_id = 40 [default = 0]; // optional int32 bos_id = 41 [default = 1]; // optional int32 eos_id = 42 [default = 2]; // diff --git a/convert/tokenizer.go b/convert/tokenizer.go index 14d6ba66c..e7be8e402 100644 --- a/convert/tokenizer.go +++ b/convert/tokenizer.go @@ -10,6 +10,7 @@ import ( "log/slog" "os" "slices" + "strings" "golang.org/x/exp/maps" ) @@ -60,7 +61,25 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) addedTokens[t.Content] = t } - t.Merges = tt.Model.Merges + if len(tt.Model.Merges) == 0 { + // noop; merges is empty + } else if err := json.Unmarshal(tt.Model.Merges, &t.Merges); err == nil { + // noop; merges is []string + } else if merges, err := func() ([][]string, error) { + var merges [][]string + if err := json.Unmarshal(tt.Model.Merges, &merges); err != nil { + return nil, err + } + + return merges, nil + }(); err == nil { + t.Merges = make([]string, len(merges)) + for i := range merges { + t.Merges[i] = strings.Join(merges[i], " ") + } + } else { + return nil, fmt.Errorf("could not parse tokenizer merges. expected []string or [][]string: %w", err) + } sha256sum := sha256.New() for _, pt := range tt.PreTokenizer.PreTokenizers { @@ -156,9 +175,9 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error) type tokenizer struct { AddedTokens []token `json:"added_tokens"` Model struct { - Type string `json:"type"` - Vocab map[string]int `json:"vocab"` - Merges []string `json:"merges"` + Type string `json:"type"` + Vocab map[string]int `json:"vocab"` + Merges json.RawMessage `json:"merges"` } `json:"model"` PreTokenizer struct { diff --git a/convert/tokenizer_test.go b/convert/tokenizer_test.go index d9550e095..c6ef9732f 100644 --- a/convert/tokenizer_test.go +++ b/convert/tokenizer_test.go @@ -191,6 +191,62 @@ func TestParseTokenizer(t *testing.T) { Pre: "default", }, }, + { + name: "list string merges", + fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ + "tokenizer.json": strings.NewReader(`{ + "model": { + "merges": [ + "a b", + "c d", + "e f" + ] + } + }`), + }), + want: &Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + }, + Merges: []string{ + "a b", + "c d", + "e f", + }, + Pre: "default", + }, + }, + { + name: "list list string merges", + fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{ + "tokenizer.json": strings.NewReader(`{ + "model": { + "merges": [ + [ + "a", "b" + ], + [ + "c", "d" + ], + [ + "e", "f" + ] + ] + } + }`), + }), + want: &Tokenizer{ + Vocabulary: &Vocabulary{ + Model: "gpt2", + }, + Merges: []string{ + "a b", + "c d", + "e f", + }, + Pre: "default", + }, + }, } for _, tt := range cases { diff --git a/gpu/amd_common.go b/discover/amd_common.go similarity index 83% rename from gpu/amd_common.go rename to discover/amd_common.go index 2894ac2c4..3c6308610 100644 --- a/gpu/amd_common.go +++ b/discover/amd_common.go @@ -1,6 +1,6 @@ //go:build linux || windows -package gpu +package discover import ( "errors" @@ -37,19 +37,6 @@ func GetSupportedGFX(libDir string) ([]string, error) { return ret, nil } -func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { - ids := []string{} - for _, info := range gpuInfo { - if info.Library != "rocm" { - // TODO shouldn't happen if things are wired correctly... 
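
The merges handling above accepts both tokenizer.json layouts by keeping the field as json.RawMessage. A self-contained sketch of the same dual-format decode, trimmed down for illustration (function and package names are placeholders):

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// parseMerges decodes a "merges" field that may be either []string
// (entries like "a b") or [][]string (entries like ["a", "b"]),
// normalizing to the flat form.
func parseMerges(raw json.RawMessage) ([]string, error) {
	if len(raw) == 0 {
		return nil, nil
	}

	var flat []string
	if err := json.Unmarshal(raw, &flat); err == nil {
		return flat, nil
	}

	var nested [][]string
	if err := json.Unmarshal(raw, &nested); err != nil {
		return nil, fmt.Errorf("could not parse merges, expected []string or [][]string: %w", err)
	}

	merges := make([]string, len(nested))
	for i := range nested {
		merges[i] = strings.Join(nested[i], " ")
	}
	return merges, nil
}

func main() {
	merges, err := parseMerges(json.RawMessage(`[["a","b"],["c","d"]]`))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%q\n", merges) // ["a b" "c d"]
}
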
- slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) - continue - } - ids = append(ids, info.ID) - } - return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",") -} - func commonAMDValidateLibDir() (string, error) { // Favor our bundled version diff --git a/gpu/amd_hip_windows.go b/discover/amd_hip_windows.go similarity index 97% rename from gpu/amd_hip_windows.go rename to discover/amd_hip_windows.go index 2cea28242..bf19ef064 100644 --- a/gpu/amd_hip_windows.go +++ b/discover/amd_hip_windows.go @@ -1,4 +1,4 @@ -package gpu +package discover import ( "errors" @@ -64,7 +64,7 @@ func NewHipLib() (*HipLib, error) { return hl, nil } -// The hip library only evaluates the HIP_VISIBLE_DEVICES variable at startup +// The hip library only evaluates the ROCR_VISIBLE_DEVICES variable at startup // so we have to unload/reset the library after we do our initial discovery // to make sure our updates to that variable are processed by llama.cpp func (hl *HipLib) Release() { diff --git a/gpu/amd_linux.go b/discover/amd_linux.go similarity index 80% rename from gpu/amd_linux.go rename to discover/amd_linux.go index d3f5b9fc6..ecf91056d 100644 --- a/gpu/amd_linux.go +++ b/discover/amd_linux.go @@ -1,4 +1,4 @@ -package gpu +package discover import ( "bufio" @@ -47,10 +47,11 @@ var ( ) // Gather GPU information from the amdgpu driver if any supported GPUs are detected -func AMDGetGPUInfo() []RocmGPUInfo { +// Only called once during bootstrap +func AMDGetGPUInfo() ([]RocmGPUInfo, error) { resp := []RocmGPUInfo{} if !AMDDetected() { - return resp + return resp, fmt.Errorf("AMD GPUs not detected") } // Opportunistic logging of driver version to aid in troubleshooting @@ -63,22 +64,20 @@ func AMDGetGPUInfo() []RocmGPUInfo { // Determine if the user has already pre-selected which GPUs to look at, then ignore the others var visibleDevices []string hipVD := envconfig.HipVisibleDevices() // zero based index only - rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID, but consumer cards seem to not support UUID + rocrVD := envconfig.RocrVisibleDevices() // zero based index or UUID gpuDO := envconfig.GpuDeviceOrdinal() // zero based index switch { - // TODO is this priorty order right? - case hipVD != "": - visibleDevices = strings.Split(hipVD, ",") case rocrVD != "": visibleDevices = strings.Split(rocrVD, ",") - // TODO - since we don't yet support UUIDs, consider detecting and reporting here - // all our test systems show GPU-XX indicating UUID is not supported + case hipVD != "": + visibleDevices = strings.Split(hipVD, ",") case gpuDO != "": visibleDevices = strings.Split(gpuDO, ",") } gfxOverride := envconfig.HsaOverrideGfxVersion() var supported []string + depPaths := LibraryDirs() libDir := "" // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract @@ -98,7 +97,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { } return a < b }) - cpuCount := 0 + gpuCount := 0 for _, match := range matches { slog.Debug("evaluating amdgpu node " + match) fp, err := os.Open(match) @@ -107,11 +106,6 @@ func AMDGetGPUInfo() []RocmGPUInfo { continue } defer fp.Close() - nodeID, err := strconv.Atoi(filepath.Base(filepath.Dir(match))) - if err != nil { - slog.Debug("failed to parse node ID", "error", err) - continue - } scanner := bufio.NewScanner(fp) isCPU := false @@ -185,24 +179,19 @@ func AMDGetGPUInfo() []RocmGPUInfo { // do reliably report VRAM usage. 
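
The reordered switch above now prefers ROCR_VISIBLE_DEVICES over HIP_VISIBLE_DEVICES, with GPU_DEVICE_ORDINAL as the last fallback. A standalone sketch of that selection order, using os.Getenv directly rather than the envconfig helpers the real code goes through:

package main

import (
	"fmt"
	"os"
	"strings"
)

// visibleDevices returns the user-requested GPU filter, checking the
// environment variables in the same priority order as the change above:
// ROCR_VISIBLE_DEVICES (accepts UUIDs or indexes), then HIP_VISIBLE_DEVICES,
// then GPU_DEVICE_ORDINAL (numeric only).
func visibleDevices() []string {
	for _, v := range []string{"ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "GPU_DEVICE_ORDINAL"} {
		if val := os.Getenv(v); val != "" {
			return strings.Split(val, ",")
		}
	}
	return nil
}

func main() {
	fmt.Println(visibleDevices())
}
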
if isCPU { - cpuCount++ continue } - // CPUs are always first in the list - gpuID := nodeID - cpuCount - - // Shouldn't happen, but just in case... - if gpuID < 0 { - slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue") - return nil - } - - if int(major) < RocmComputeMin { - slog.Warn(fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch), "gpu", gpuID) + // Skip over any GPUs that are masked + if major == 0 && minor == 0 && patch == 0 { + slog.Debug("skipping gpu with gfx000") continue } + // Keep track of numeric IDs based on valid GPUs + gpuID := gpuCount + gpuCount += 1 + // Look up the memory for the current node totalMemory := uint64(0) usedMemory := uint64(0) @@ -270,19 +259,20 @@ func AMDGetGPUInfo() []RocmGPUInfo { break } - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if totalMemory < IGPUMemLimit { - slog.Info("unsupported Radeon iGPU detected skipping", "id", gpuID, "total", format.HumanBytes2(totalMemory)) - continue - } var name string // TODO - PCI ID lookup if vendor > 0 && device > 0 { name = fmt.Sprintf("%04x:%04x", vendor, device) } - slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory)) + // Favor UUIDs if available to reduce possibility of getting the numeric IDs wrong + var ID string + if uniqueID != 0 { + ID = fmt.Sprintf("GPU-%016x", uniqueID) + } else { + ID = strconv.Itoa(gpuID) + } + gpuInfo := RocmGPUInfo{ GpuInfo: GpuInfo{ Library: "rocm", @@ -290,7 +280,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { TotalMemory: totalMemory, FreeMemory: (totalMemory - usedMemory), }, - ID: strconv.Itoa(gpuID), + ID: ID, Name: name, Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch), MinimumMemory: rocmMinimumMemory, @@ -298,19 +288,54 @@ func AMDGetGPUInfo() []RocmGPUInfo { DriverMinor: driverMinor, }, usedFilepath: usedFile, + index: gpuID, } + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library + if totalMemory < IGPUMemLimit { + reason := "unsupported Radeon iGPU detected skipping" + slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory)) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + continue + } + minVer, err := strconv.Atoi(RocmComputeMajorMin) + if err != nil { + slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err) + } + if int(major) < minVer { + reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch) + slog.Warn(reason, "gpu", gpuID) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + + continue + } + + slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) + slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory)) + // If the user wants to filter to a subset of devices, filter out if we aren't a match if len(visibleDevices) > 0 { include := false for _, visible := range visibleDevices { - if visible == gpuInfo.ID { + if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) { include = true break } } if !include { - slog.Info("filtering out device per user request", "id", gpuInfo.ID, "visible_devices", visibleDevices) + reason := "filtering out device per user request" + slog.Info(reason, "id", 
gpuInfo.ID, "visible_devices", visibleDevices) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + continue } } @@ -320,25 +345,42 @@ func AMDGetGPUInfo() []RocmGPUInfo { if libDir == "" { libDir, err = AMDValidateLibDir() if err != nil { - slog.Warn("unable to verify rocm library, will use cpu", "error", err) - return nil + err = fmt.Errorf("unable to verify rocm library: %w", err) + slog.Warn(err.Error()) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: err.Error(), + }) + return nil, err } + depPaths = append(depPaths, libDir) } - gpuInfo.DependencyPath = libDir + gpuInfo.DependencyPath = depPaths if gfxOverride == "" { // Only load supported list once if len(supported) == 0 { supported, err = GetSupportedGFX(libDir) if err != nil { - slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err) - return nil + err = fmt.Errorf("failed to lookup supported GFX types: %w", err) + slog.Warn(err.Error()) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: err.Error(), + }) + return nil, err } slog.Debug("rocm supported GPUs", "types", supported) } gfx := gpuInfo.Compute if !slices.Contains[[]string, string](supported, gfx) { - slog.Warn("amdgpu is not supported", "gpu", gpuInfo.ID, "gpu_type", gfx, "library", libDir, "supported_types", supported) + reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) + slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + // TODO - consider discrete markdown just for ROCM troubleshooting? slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage") continue @@ -358,13 +400,16 @@ func AMDGetGPUInfo() []RocmGPUInfo { resp = append(resp, gpuInfo) } if len(resp) == 0 { - slog.Info("no compatible amdgpu devices detected") + err := fmt.Errorf("no compatible amdgpu devices detected") + slog.Info(err.Error()) + return nil, err } if err := verifyKFDDriverAccess(); err != nil { - slog.Error("amdgpu devices detected but permission problems block access", "error", err) - return nil + err = fmt.Errorf("amdgpu devices detected but permission problems block access: %w", err) + slog.Error(err.Error()) + return nil, err } - return resp + return resp, nil } // Quick check for AMD driver so we can skip amdgpu discovery if not present @@ -476,3 +521,20 @@ func verifyKFDDriverAccess() error { fd.Close() return nil } + +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "rocm" { + // TODO shouldn't happen if things are wired correctly... + slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) + continue + } + ids = append(ids, info.ID) + } + // There are 3 potential env vars to use to select GPUs. 
+ // ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux + // GPU_DEVICE_ORDINAL supports numeric IDs only + // HIP_VISIBLE_DEVICES supports numeric IDs only + return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",") +} diff --git a/gpu/amd_windows.go b/discover/amd_windows.go similarity index 69% rename from gpu/amd_windows.go rename to discover/amd_windows.go index ef6bf830c..9477bedc7 100644 --- a/gpu/amd_windows.go +++ b/discover/amd_windows.go @@ -1,8 +1,9 @@ -package gpu +package discover import ( "bytes" "errors" + "fmt" "log/slog" "os" "path/filepath" @@ -26,12 +27,13 @@ var ( RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob? ) -func AMDGetGPUInfo() []RocmGPUInfo { +// Only called once during bootstrap +func AMDGetGPUInfo() ([]RocmGPUInfo, error) { resp := []RocmGPUInfo{} hl, err := NewHipLib() if err != nil { slog.Debug(err.Error()) - return nil + return nil, err } defer hl.Release() @@ -41,24 +43,30 @@ func AMDGetGPUInfo() []RocmGPUInfo { slog.Debug("error looking up amd driver version", "error", err) } - // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified + // Note: the HIP library automatically handles subsetting to any *_VISIBLE_DEVICES the user specified count := hl.HipGetDeviceCount() if count == 0 { - return nil + err := fmt.Errorf("no compatible amdgpu devices detected") + slog.Info(err.Error()) + return nil, err } + depPaths := LibraryDirs() libDir, err := AMDValidateLibDir() if err != nil { - slog.Warn("unable to verify rocm library, will use cpu", "error", err) - return nil + err = fmt.Errorf("unable to verify rocm library: %w", err) + slog.Warn(err.Error()) + return nil, err } + depPaths = append(depPaths, libDir) var supported []string gfxOverride := envconfig.HsaOverrideGfxVersion() if gfxOverride == "" { supported, err = GetSupportedGFX(libDir) if err != nil { - slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err) - return nil + err = fmt.Errorf("failed to lookup supported GFX types: %w", err) + slog.Warn(err.Error()) + return nil, err } } else { slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride) @@ -87,21 +95,6 @@ func AMDGetGPUInfo() []RocmGPUInfo { slog.Debug("hip device", "id", i, "name", name, "gfx", gfx) // slog.Info(fmt.Sprintf("[%d] Integrated: %d", i, props.iGPU)) // DOESN'T REPORT CORRECTLY! Always 0 // TODO Why isn't props.iGPU accurate!? - if strings.EqualFold(name, iGPUName) { - slog.Info("unsupported Radeon iGPU detected skipping", "id", i, "name", name, "gfx", gfx) - continue - } - if gfxOverride == "" { - // Strip off Target Features when comparing - if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) { - slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported) - // TODO - consider discrete markdown just for ROCM troubleshooting? 
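
A small sketch of the Linux ID scheme added earlier in this patch, which favors a UUID-style identifier built from the sysfs unique_id and falls back to the positional index when no unique_id is exposed (the helper name is illustrative):

package main

import (
	"fmt"
	"strconv"
)

// gpuID prefers a stable "GPU-<16 hex digits>" identifier when the kernel
// reports a non-zero unique_id, and otherwise uses the numeric index.
func gpuID(uniqueID uint64, index int) string {
	if uniqueID != 0 {
		return fmt.Sprintf("GPU-%016x", uniqueID)
	}
	return strconv.Itoa(index)
}

func main() {
	fmt.Println(gpuID(0xdeadbeef, 0)) // GPU-00000000deadbeef
	fmt.Println(gpuID(0, 1))          // 1
}
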
- slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") - continue - } else { - slog.Debug("amdgpu is supported", "gpu", i, "gpu_type", gfx) - } - } freeMemory, totalMemory, err := hl.HipMemGetInfo() if err != nil { @@ -109,14 +102,6 @@ func AMDGetGPUInfo() []RocmGPUInfo { continue } - // iGPU detection, remove this check once we can support an iGPU variant of the rocm library - if totalMemory < IGPUMemLimit { - slog.Info("amdgpu appears to be an iGPU, skipping", "gpu", i, "total", format.HumanBytes2(totalMemory)) - continue - } - - slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) gpuInfo := RocmGPUInfo{ GpuInfo: GpuInfo{ Library: "rocm", @@ -128,7 +113,7 @@ func AMDGetGPUInfo() []RocmGPUInfo { UnreliableFreeMemory: true, ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices - DependencyPath: libDir, + DependencyPath: depPaths, MinimumMemory: rocmMinimumMemory, Name: name, Compute: gfx, @@ -138,10 +123,38 @@ func AMDGetGPUInfo() []RocmGPUInfo { index: i, } + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library + if strings.EqualFold(name, iGPUName) || totalMemory < IGPUMemLimit { + reason := "unsupported Radeon iGPU detected skipping" + slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + continue + } + + // Strip off Target Features when comparing + if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) { + reason := fmt.Sprintf("amdgpu is not supported (supported types:%s)", supported) + slog.Warn(reason, "gpu_type", gfx, "gpu", gpuInfo.ID, "library", libDir) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + // HSA_OVERRIDE_GFX_VERSION not supported on windows + continue + } else { + slog.Debug("amdgpu is supported", "gpu", i, "gpu_type", gfx) + } + + slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory)) + slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory)) + resp = append(resp, gpuInfo) } - return resp + return resp, nil } func AMDValidateLibDir() (string, error) { @@ -171,7 +184,7 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error { hl, err := NewHipLib() if err != nil { slog.Debug(err.Error()) - return nil + return err } defer hl.Release() @@ -190,3 +203,20 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error { } return nil } + +func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "rocm" { + // TODO shouldn't happen if things are wired correctly... + slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library) + continue + } + ids = append(ids, info.ID) + } + // There are 3 potential env vars to use to select GPUs. 
+ // ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows + // HIP_VISIBLE_DEVICES supports numeric IDs only + // GPU_DEVICE_ORDINAL supports numeric IDs only + return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",") +} diff --git a/gpu/cpu_common.go b/discover/cpu_common.go similarity index 68% rename from gpu/cpu_common.go rename to discover/cpu_common.go index 34edcdc5a..242e48791 100644 --- a/gpu/cpu_common.go +++ b/discover/cpu_common.go @@ -1,25 +1,12 @@ -package gpu +package discover import ( "os" "path/filepath" "runtime" "strings" - - "golang.org/x/sys/cpu" ) -func GetCPUCapability() CPUCapability { - if cpu.X86.HasAVX2 { - return CPUCapabilityAVX2 - } - if cpu.X86.HasAVX { - return CPUCapabilityAVX - } - // else LCD - return CPUCapabilityNone -} - func IsNUMA() bool { if runtime.GOOS != "linux" { // numa support in llama.cpp is linux only diff --git a/gpu/cuda_common.go b/discover/cuda_common.go similarity index 99% rename from gpu/cuda_common.go rename to discover/cuda_common.go index aceec70af..878cee8cb 100644 --- a/gpu/cuda_common.go +++ b/discover/cuda_common.go @@ -1,6 +1,6 @@ //go:build linux || windows -package gpu +package discover import ( "log/slog" diff --git a/gpu/gpu.go b/discover/gpu.go similarity index 74% rename from gpu/gpu.go rename to discover/gpu.go index 1bd337f19..15d0b99a4 100644 --- a/gpu/gpu.go +++ b/discover/gpu.go @@ -1,6 +1,6 @@ //go:build linux || windows -package gpu +package discover /* #cgo linux LDFLAGS: -lrt -lpthread -ldl -lstdc++ -lm @@ -16,12 +16,14 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "sync" "unsafe" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/runners" ) type cudaHandles struct { @@ -50,7 +52,6 @@ const ( var ( gpuMutex sync.Mutex bootstrapped bool - cpuCapability CPUCapability cpus []CPUInfo cudaGPUs []CudaGPUInfo nvcudaLibPath string @@ -62,12 +63,23 @@ var ( rocmGPUs []RocmGPUInfo oneapiGPUs []OneapiGPUInfo vulkanGPUs []VulkanGPUInfo + + // If any discovered GPUs are incompatible, report why + unsupportedGPUs []UnsupportedGPUInfo + + // Keep track of errors during bootstrapping so that if GPUs are missing + // they expected to be present this may explain why + bootstrapErrors []error ) // With our current CUDA compile flags, older than 5.0 will not work properly -var CudaComputeMin = [2]C.int{5, 0} +// (string values used to allow ldflags overrides at build time) +var ( + CudaComputeMajorMin = "5" + CudaComputeMinorMin = "0" +) -var RocmComputeMin = 9 +var RocmComputeMajorMin = "9" // TODO find a better way to detect iGPU instead of minimum memory const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU @@ -78,16 +90,17 @@ func initCudaHandles() *cudaHandles { cHandles := &cudaHandles{} // Short Circuit if we already know which library to use + // ignore bootstrap errors in this case since we already recorded them if nvmlLibPath != "" { - cHandles.nvml, _ = LoadNVMLMgmt([]string{nvmlLibPath}) + cHandles.nvml, _, _ = loadNVMLMgmt([]string{nvmlLibPath}) return cHandles } if nvcudaLibPath != "" { - cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath}) + cHandles.deviceCount, cHandles.nvcuda, _, _ = loadNVCUDAMgmt([]string{nvcudaLibPath}) return cHandles } if cudartLibPath != "" { - cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath}) + cHandles.deviceCount, cHandles.cudart, _, _ = loadCUDARTMgmt([]string{cudartLibPath}) 
return cHandles } @@ -101,27 +114,30 @@ func initCudaHandles() *cudaHandles { localAppData := os.Getenv("LOCALAPPDATA") cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)} } - libDir := LibraryDir() - if libDir != "" { - cudartMgmtPatterns = []string{filepath.Join(libDir, CudartMgmtName)} + libDirs := LibraryDirs() + for _, d := range libDirs { + cudartMgmtPatterns = append(cudartMgmtPatterns, filepath.Join(d, CudartMgmtName)) } cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...) if len(NvmlGlobs) > 0 { nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs) if len(nvmlLibPaths) > 0 { - nvml, libPath := LoadNVMLMgmt(nvmlLibPaths) + nvml, libPath, err := loadNVMLMgmt(nvmlLibPaths) if nvml != nil { slog.Debug("nvidia-ml loaded", "library", libPath) cHandles.nvml = nvml nvmlLibPath = libPath } + if err != nil { + bootstrapErrors = append(bootstrapErrors, err) + } } } nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns) if len(nvcudaLibPaths) > 0 { - deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths) + deviceCount, nvcuda, libPath, err := loadNVCUDAMgmt(nvcudaLibPaths) if nvcuda != nil { slog.Debug("detected GPUs", "count", deviceCount, "library", libPath) cHandles.nvcuda = nvcuda @@ -129,11 +145,14 @@ func initCudaHandles() *cudaHandles { nvcudaLibPath = libPath return cHandles } + if err != nil { + bootstrapErrors = append(bootstrapErrors, err) + } } cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns) if len(cudartLibPaths) > 0 { - deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths) + deviceCount, cudart, libPath, err := loadCUDARTMgmt(cudartLibPaths) if cudart != nil { slog.Debug("detected GPUs", "library", libPath, "count", deviceCount) cHandles.cudart = cudart @@ -141,6 +160,9 @@ func initCudaHandles() *cudaHandles { cudartLibPath = libPath return cHandles } + if err != nil { + bootstrapErrors = append(bootstrapErrors, err) + } } return cHandles @@ -151,14 +173,19 @@ func initOneAPIHandles() *oneapiHandles { oHandles := &oneapiHandles{} // Short Circuit if we already know which library to use + // ignore bootstrap errors in this case since we already recorded them if oneapiLibPath != "" { - oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath}) + oHandles.deviceCount, oHandles.oneapi, _, _ = loadOneapiMgmt([]string{oneapiLibPath}) return oHandles } oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs) if len(oneapiLibPaths) > 0 { - oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths) + var err error + oHandles.deviceCount, oHandles.oneapi, oneapiLibPath, err = loadOneapiMgmt(oneapiLibPaths) + if err != nil { + bootstrapErrors = append(bootstrapErrors, err) + } } return oHandles @@ -231,36 +258,40 @@ func GetGPUInfo() GpuInfoList { if !bootstrapped { slog.Info("looking for compatible GPUs") + cudaComputeMajorMin, err := strconv.Atoi(CudaComputeMajorMin) + if err != nil { + slog.Error("invalid CudaComputeMajorMin setting", "value", CudaComputeMajorMin, "error", err) + } + cudaComputeMinorMin, err := strconv.Atoi(CudaComputeMinorMin) + if err != nil { + slog.Error("invalid CudaComputeMinorMin setting", "value", CudaComputeMinorMin, "error", err) + } + bootstrapErrors = []error{} needRefresh = false - cpuCapability = GetCPUCapability() var memInfo C.mem_info_t mem, err := GetCPUMem() if err != nil { slog.Warn("error looking up system memory", "error", err) } - depPath := LibraryDir() - + depPaths := LibraryDirs() + details, 
err := GetCPUDetails() + if err != nil { + slog.Warn("failed to lookup CPU details", "error", err) + } cpus = []CPUInfo{ { GpuInfo: GpuInfo{ memInfo: mem, Library: "cpu", - Variant: cpuCapability.String(), + Variant: runners.GetCPUCapability().String(), ID: "0", - DependencyPath: depPath, + DependencyPath: depPaths, }, + CPUs: details, }, } - // Fallback to CPU mode if we're lacking required vector extensions on x86 - if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" { - slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability, "detected", cpuCapability) - bootstrapped = true - // No need to do any GPU discovery, since we can't run on them - return GpuInfoList{cpus[0].GpuInfo} - } - // Load ALL libraries cHandles = initCudaHandles() @@ -287,10 +318,6 @@ func GetGPUInfo() GpuInfoList { C.free(unsafe.Pointer(memInfo.err)) continue } - if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) { - slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) - continue - } gpuInfo.TotalMemory = uint64(memInfo.total) gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) @@ -301,22 +328,37 @@ func GetGPUInfo() GpuInfoList { gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMinor = driverMinor variant := cudaVariant(gpuInfo) - if depPath != "" { - gpuInfo.DependencyPath = depPath + if depPaths != nil { + gpuInfo.DependencyPath = depPaths // Check for variant specific directory if variant != "" { - if _, err := os.Stat(filepath.Join(depPath, "cuda_"+variant)); err == nil { - gpuInfo.DependencyPath = filepath.Join(depPath, "cuda_"+variant) + for _, d := range depPaths { + if _, err := os.Stat(filepath.Join(d, "cuda_"+variant)); err == nil { + // Put the variant directory first in the search path to avoid runtime linking to the wrong library + gpuInfo.DependencyPath = append([]string{filepath.Join(d, "cuda_"+variant)}, gpuInfo.DependencyPath...) + break + } } } } gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) gpuInfo.Variant = variant + if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { + unsupportedGPUs = append(unsupportedGPUs, + UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + }) + slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. 
Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor)) + continue + } + // query the management library as well so we can record any skew between the two // which represents overhead on the GPU we must set aside on subsequent updates if cHandles.nvml != nil { - C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used) + uuid := C.CString(gpuInfo.ID) + defer C.free(unsafe.Pointer(uuid)) + C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) if memInfo.err != nil { slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err)) C.free(unsafe.Pointer(memInfo.err)) @@ -368,7 +410,7 @@ func GetGPUInfo() GpuInfoList { gpuInfo.FreeMemory = uint64(memInfo.free) gpuInfo.ID = C.GoString(&memInfo.gpu_id[0]) gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) - gpuInfo.DependencyPath = depPath + gpuInfo.DependencyPath = depPaths oneapiGPUs = append(oneapiGPUs, gpuInfo) } } @@ -408,11 +450,16 @@ func GetGPUInfo() GpuInfoList { } } - rocmGPUs = AMDGetGPUInfo() + rocmGPUs, err = AMDGetGPUInfo() + if err != nil { + bootstrapErrors = append(bootstrapErrors, err) + } bootstrapped = true if len(cudaGPUs) == 0 && len(rocmGPUs) == 0 && len(oneapiGPUs) == 0 && len(vulkanGPUs) == 0 { slog.Info("no compatible GPUs were discovered") } + + // TODO verify we have runners for the discovered GPUs, filter out any that aren't supported with good error messages } // For detected GPUs, load library if not loaded @@ -447,7 +494,9 @@ func GetGPUInfo() GpuInfoList { } for i, gpu := range cudaGPUs { if cHandles.nvml != nil { - C.nvml_get_free(*cHandles.nvml, C.int(gpu.index), &memInfo.free, &memInfo.total, &memInfo.used) + uuid := C.CString(gpu.ID) + defer C.free(unsafe.Pointer(uuid)) + C.nvml_get_free(*cHandles.nvml, uuid, &memInfo.free, &memInfo.total, &memInfo.used) } else if cHandles.cudart != nil { C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo) } else if cHandles.nvcuda != nil { @@ -554,7 +603,10 @@ func FindGPULibs(baseLibName string, defaultPatterns []string) []string { slog.Debug("Searching for GPU library", "name", baseLibName) // Start with our bundled libraries - patterns := []string{filepath.Join(LibraryDir(), baseLibName)} + patterns := []string{} + for _, d := range LibraryDirs() { + patterns = append(patterns, filepath.Join(d, baseLibName)) + } switch runtime.GOOS { case "windows": @@ -576,7 +628,6 @@ func FindGPULibs(baseLibName string, defaultPatterns []string) []string { patterns = append(patterns, defaultPatterns...) 
slog.Debug("gpu library search", "globs", patterns) for _, pattern := range patterns { - // Nvidia PhysX known to return bogus results if strings.Contains(pattern, "PhysX") { slog.Debug("skipping PhysX cuda library path", "path", pattern) @@ -612,92 +663,114 @@ func FindGPULibs(baseLibName string, defaultPatterns []string) []string { return gpuLibPaths } -func LoadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string) { +// Bootstrap the runtime library +// Returns: num devices, handle, libPath, error +func loadCUDARTMgmt(cudartLibPaths []string) (int, *C.cudart_handle_t, string, error) { var resp C.cudart_init_resp_t resp.ch.verbose = getVerboseState() + var err error for _, libPath := range cudartLibPaths { lib := C.CString(libPath) defer C.free(unsafe.Pointer(lib)) C.cudart_init(lib, &resp) if resp.err != nil { - slog.Debug("Unable to load cudart", "library", libPath, "error", C.GoString(resp.err)) + err = fmt.Errorf("Unable to load cudart library %s: %s", libPath, C.GoString(resp.err)) + slog.Debug(err.Error()) C.free(unsafe.Pointer(resp.err)) } else { - return int(resp.num_devices), &resp.ch, libPath + err = nil + return int(resp.num_devices), &resp.ch, libPath, err } } - return 0, nil, "" + return 0, nil, "", err } -func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) { +// Bootstrap the driver library +// Returns: num devices, handle, libPath, error +func loadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string, error) { var resp C.nvcuda_init_resp_t resp.ch.verbose = getVerboseState() + var err error for _, libPath := range nvcudaLibPaths { lib := C.CString(libPath) defer C.free(unsafe.Pointer(lib)) C.nvcuda_init(lib, &resp) if resp.err != nil { // Decide what log level based on the type of error message to help users understand why - msg := C.GoString(resp.err) switch resp.cudaErr { case C.CUDA_ERROR_INSUFFICIENT_DRIVER, C.CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: - slog.Warn("version mismatch between driver and cuda driver library - reboot or upgrade may be required", "library", libPath, "error", msg) + err = fmt.Errorf("version mismatch between driver and cuda driver library - reboot or upgrade may be required: library %s", libPath) + slog.Warn(err.Error()) case C.CUDA_ERROR_NO_DEVICE: - slog.Info("no nvidia devices detected", "library", libPath) + err = fmt.Errorf("no nvidia devices detected by library %s", libPath) + slog.Info(err.Error()) case C.CUDA_ERROR_UNKNOWN: - slog.Warn("unknown error initializing cuda driver library", "library", libPath, "error", msg) - slog.Warn("see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information") + err = fmt.Errorf("unknown error initializing cuda driver library %s: %s. 
see https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for more information", libPath, C.GoString(resp.err)) + slog.Warn(err.Error()) default: + msg := C.GoString(resp.err) if strings.Contains(msg, "wrong ELF class") { slog.Debug("skipping 32bit library", "library", libPath) } else { - slog.Info("unable to load cuda driver library", "library", libPath, "error", msg) + err = fmt.Errorf("Unable to load cuda driver library %s: %s", libPath, C.GoString(resp.err)) + slog.Info(err.Error()) } } C.free(unsafe.Pointer(resp.err)) } else { - return int(resp.num_devices), &resp.ch, libPath + err = nil + return int(resp.num_devices), &resp.ch, libPath, err } } - return 0, nil, "" + return 0, nil, "", err } -func LoadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string) { +// Bootstrap the management library +// Returns: handle, libPath, error +func loadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string, error) { var resp C.nvml_init_resp_t resp.ch.verbose = getVerboseState() + var err error for _, libPath := range nvmlLibPaths { lib := C.CString(libPath) defer C.free(unsafe.Pointer(lib)) C.nvml_init(lib, &resp) if resp.err != nil { - slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err))) + err = fmt.Errorf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)) + slog.Info(err.Error()) C.free(unsafe.Pointer(resp.err)) } else { - return &resp.ch, libPath + err = nil + return &resp.ch, libPath, err } } - return nil, "" + return nil, "", err } -func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) { +// Bootstrap the Intel GPU library +// Returns: num devices, handle, libPath, error +func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, error) { var resp C.oneapi_init_resp_t num_devices := 0 resp.oh.verbose = getVerboseState() + var err error for _, libPath := range oneapiLibPaths { lib := C.CString(libPath) defer C.free(unsafe.Pointer(lib)) C.oneapi_init(lib, &resp) if resp.err != nil { - slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err)) + err = fmt.Errorf("Unable to load oneAPI management library %s: %s", libPath, C.GoString(resp.err)) + slog.Debug(err.Error()) C.free(unsafe.Pointer(resp.err)) } else { + err = nil for i := range resp.oh.num_drivers { num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i))) } - return num_devices, &resp.oh, libPath + return num_devices, &resp.oh, libPath, err } } - return 0, nil, "" + return 0, nil, "", err } func LoadVulkanMgmt(vulkanLibPaths []string, capLibPaths []string) (int, *C.vk_handle_t, string, string) { @@ -753,30 +826,44 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { } } -func LibraryDir() string { - // On Windows/linux we bundle the dependencies at the same level as the executable +func LibraryDirs() []string { + // dependencies can exist wherever we found the runners (e.g. 
build tree for developers) and relative to the executable + // This can be simplified once we no longer carry runners as payloads + paths := []string{} appExe, err := os.Executable() if err != nil { slog.Warn("failed to lookup executable path", "error", err) + } else { + appRelative := filepath.Join(filepath.Dir(appExe), envconfig.LibRelativeToExe(), "lib", "ollama") + if _, err := os.Stat(appRelative); err == nil { + paths = append(paths, appRelative) + } } - cwd, err := os.Getwd() + rDir, err := runners.Locate() if err != nil { - slog.Warn("failed to lookup working directory", "error", err) + slog.Warn("unable to locate gpu dependency libraries", "error", err) + } else { + paths = append(paths, filepath.Dir(rDir)) + } + return paths +} + +func GetSystemInfo() SystemInfo { + gpus := GetGPUInfo() + gpuMutex.Lock() + defer gpuMutex.Unlock() + discoveryErrors := []string{} + for _, err := range bootstrapErrors { + discoveryErrors = append(discoveryErrors, err.Error()) + } + if len(gpus) == 1 && gpus[0].Library == "cpu" { + gpus = []GpuInfo{} + } + + return SystemInfo{ + System: cpus[0], + GPUs: gpus, + UnsupportedGPUs: unsupportedGPUs, + DiscoveryErrors: discoveryErrors, } - // Scan for any of our dependeices, and pick first match - for _, root := range []string{filepath.Dir(appExe), filepath.Join(filepath.Dir(appExe), envconfig.LibRelativeToExe()), cwd} { - libDep := filepath.Join("lib", "ollama") - if _, err := os.Stat(filepath.Join(root, libDep)); err == nil { - return filepath.Join(root, libDep) - } - // Developer mode, local build - if _, err := os.Stat(filepath.Join(root, runtime.GOOS+"-"+runtime.GOARCH, libDep)); err == nil { - return filepath.Join(root, runtime.GOOS+"-"+runtime.GOARCH, libDep) - } - if _, err := os.Stat(filepath.Join(root, "dist", runtime.GOOS+"-"+runtime.GOARCH, libDep)); err == nil { - return filepath.Join(root, "dist", runtime.GOOS+"-"+runtime.GOARCH, libDep) - } - } - slog.Warn("unable to locate gpu dependency libraries") - return "" } diff --git a/gpu/gpu_darwin.go b/discover/gpu_darwin.go similarity index 55% rename from gpu/gpu_darwin.go rename to discover/gpu_darwin.go index 417b48dfa..15f8f7996 100644 --- a/gpu/gpu_darwin.go +++ b/discover/gpu_darwin.go @@ -1,6 +1,6 @@ //go:build darwin -package gpu +package discover /* #cgo CFLAGS: -x objective-c @@ -10,9 +10,12 @@ package gpu import "C" import ( + "log/slog" "runtime" + "syscall" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/runners" ) const ( @@ -25,7 +28,7 @@ func GetGPUInfo() GpuInfoList { return []GpuInfo{ { Library: "cpu", - Variant: GetCPUCapability().String(), + Variant: runners.GetCPUCapability().String(), memInfo: mem, }, } @@ -48,7 +51,7 @@ func GetCPUInfo() GpuInfoList { return []GpuInfo{ { Library: "cpu", - Variant: GetCPUCapability().String(), + Variant: runners.GetCPUCapability().String(), memInfo: mem, }, } @@ -66,3 +69,34 @@ func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) { // No-op on darwin return "", "" } + +func GetSystemInfo() SystemInfo { + mem, _ := GetCPUMem() + query := "hw.perflevel0.physicalcpu" + perfCores, err := syscall.SysctlUint32(query) + if err != nil { + slog.Warn("failed to discover physical CPU details", "query", query, "error", err) + } + query = "hw.perflevel1.physicalcpu" + efficiencyCores, _ := syscall.SysctlUint32(query) // On x86 xeon this won't return data + + // Determine thread count + query = "hw.logicalcpu" + logicalCores, _ := syscall.SysctlUint32(query) + + return SystemInfo{ + System: CPUInfo{ + GpuInfo: GpuInfo{ + memInfo: mem, + }, 
+ CPUs: []CPU{ + { + CoreCount: int(perfCores + efficiencyCores), + EfficiencyCoreCount: int(efficiencyCores), + ThreadCount: int(logicalCores), + }, + }, + }, + GPUs: GetGPUInfo(), + } +} diff --git a/gpu/gpu_info.h b/discover/gpu_info.h similarity index 100% rename from gpu/gpu_info.h rename to discover/gpu_info.h diff --git a/gpu/gpu_info_cudart.c b/discover/gpu_info_cudart.c similarity index 100% rename from gpu/gpu_info_cudart.c rename to discover/gpu_info_cudart.c diff --git a/gpu/gpu_info_cudart.h b/discover/gpu_info_cudart.h similarity index 100% rename from gpu/gpu_info_cudart.h rename to discover/gpu_info_cudart.h diff --git a/gpu/gpu_info_darwin.h b/discover/gpu_info_darwin.h similarity index 100% rename from gpu/gpu_info_darwin.h rename to discover/gpu_info_darwin.h diff --git a/gpu/gpu_info_darwin.m b/discover/gpu_info_darwin.m similarity index 100% rename from gpu/gpu_info_darwin.m rename to discover/gpu_info_darwin.m diff --git a/gpu/gpu_info_nvcuda.c b/discover/gpu_info_nvcuda.c similarity index 94% rename from gpu/gpu_info_nvcuda.c rename to discover/gpu_info_nvcuda.c index a1a38bfc2..466e1ac24 100644 --- a/gpu/gpu_info_nvcuda.c +++ b/discover/gpu_info_nvcuda.c @@ -4,6 +4,7 @@ #include "gpu_info_nvcuda.h" void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { + LOG(resp->ch.verbose, "initializing %s\n", nvcuda_lib_path); CUresult ret; resp->err = NULL; resp->num_devices = 0; @@ -57,8 +58,10 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { resp->cudaErr = -1; return; } + LOG(resp->ch.verbose, "dlsym: %s - %p\n", l[i].s, *l[i].p); } + LOG(resp->ch.verbose, "calling cuInit\n"); ret = (*resp->ch.cuInit)(0); if (ret != CUDA_SUCCESS) { LOG(resp->ch.verbose, "cuInit err: %d\n", ret); @@ -75,15 +78,18 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { resp->ch.driver_minor = 0; // Report driver version if we're in verbose mode, ignore errors + LOG(resp->ch.verbose, "calling cuDriverGetVersion\n"); ret = (*resp->ch.cuDriverGetVersion)(&version); if (ret != CUDA_SUCCESS) { LOG(resp->ch.verbose, "cuDriverGetVersion failed: %d\n", ret); } else { + LOG(resp->ch.verbose, "raw version 0x%x\n", version); resp->ch.driver_major = version / 1000; resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10; LOG(resp->ch.verbose, "CUDA driver version: %d.%d\n", resp->ch.driver_major, resp->ch.driver_minor); } + LOG(resp->ch.verbose, "calling cuDeviceGetCount\n"); ret = (*resp->ch.cuDeviceGetCount)(&resp->num_devices); if (ret != CUDA_SUCCESS) { LOG(resp->ch.verbose, "cuDeviceGetCount err: %d\n", ret); @@ -94,6 +100,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { resp->cudaErr = ret; return; } + LOG(resp->ch.verbose, "device count %d\n", resp->num_devices); } const int buflen = 256; diff --git a/gpu/gpu_info_nvcuda.h b/discover/gpu_info_nvcuda.h similarity index 100% rename from gpu/gpu_info_nvcuda.h rename to discover/gpu_info_nvcuda.h diff --git a/gpu/gpu_info_nvml.c b/discover/gpu_info_nvml.c similarity index 86% rename from gpu/gpu_info_nvml.c rename to discover/gpu_info_nvml.c index 11293e448..342a3aa4b 100644 --- a/gpu/gpu_info_nvml.c +++ b/discover/gpu_info_nvml.c @@ -17,7 +17,7 @@ void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { } l[] = { {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2}, {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown}, - {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex}, + {"nvmlDeviceGetHandleByUUID", (void 
*)&resp->ch.nvmlDeviceGetHandleByUUID}, {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo}, {NULL, NULL}, }; @@ -67,20 +67,20 @@ void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) { } -void nvml_get_free(nvml_handle_t h, int device_id, uint64_t *free, uint64_t *total, uint64_t *used) { +void nvml_get_free(nvml_handle_t h, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used) { nvmlDevice_t device; nvmlMemory_t memInfo = {0}; nvmlReturn_t ret; - ret = (*h.nvmlDeviceGetHandleByIndex)(device_id, &device); + ret = (*h.nvmlDeviceGetHandleByUUID)((const char *)(uuid), &device); if (ret != NVML_SUCCESS) { - LOG(1, "unable to get device handle %d: %d", device_id, ret); + LOG(1, "unable to get device handle %s: %d", uuid, ret); *free = 0; return; } ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo); if (ret != NVML_SUCCESS) { - LOG(1, "device memory info lookup failure %d: %d", device_id, ret); + LOG(1, "device memory info lookup failure %s: %d", uuid, ret); *free = 0; return; } diff --git a/gpu/gpu_info_nvml.h b/discover/gpu_info_nvml.h similarity index 85% rename from gpu/gpu_info_nvml.h rename to discover/gpu_info_nvml.h index a661f723b..908802337 100644 --- a/gpu/gpu_info_nvml.h +++ b/discover/gpu_info_nvml.h @@ -25,7 +25,7 @@ typedef struct nvml_handle { uint16_t verbose; nvmlReturn_t (*nvmlInit_v2)(void); nvmlReturn_t (*nvmlShutdown)(void); - nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *); + nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *); nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *); } nvml_handle_t; @@ -41,7 +41,7 @@ typedef struct nvml_compute_capability { } nvml_compute_capability_t; void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp); -void nvml_get_free(nvml_handle_t ch, int device_id, uint64_t *free, uint64_t *total, uint64_t *used); +void nvml_get_free(nvml_handle_t ch, char *uuid, uint64_t *free, uint64_t *total, uint64_t *used); void nvml_release(nvml_handle_t ch); #endif // __GPU_INFO_NVML_H__ diff --git a/gpu/gpu_info_oneapi.c b/discover/gpu_info_oneapi.c similarity index 100% rename from gpu/gpu_info_oneapi.c rename to discover/gpu_info_oneapi.c diff --git a/gpu/gpu_info_oneapi.h b/discover/gpu_info_oneapi.h similarity index 100% rename from gpu/gpu_info_oneapi.h rename to discover/gpu_info_oneapi.h diff --git a/gpu/gpu_info_vulkan.c b/discover/gpu_info_vulkan.c similarity index 100% rename from gpu/gpu_info_vulkan.c rename to discover/gpu_info_vulkan.c diff --git a/gpu/gpu_info_vulkan.h b/discover/gpu_info_vulkan.h similarity index 100% rename from gpu/gpu_info_vulkan.h rename to discover/gpu_info_vulkan.h diff --git a/discover/gpu_linux.go b/discover/gpu_linux.go new file mode 100644 index 000000000..840ea435a --- /dev/null +++ b/discover/gpu_linux.go @@ -0,0 +1,215 @@ +package discover + +import ( + "bufio" + "fmt" + "io" + "os" + "reflect" + "regexp" + "sort" + "strings" + + "github.com/ollama/ollama/format" +) + +var CudartGlobs = []string{ + "/usr/local/cuda/lib64/libcudart.so*", + "/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*", + "/usr/lib/x86_64-linux-gnu/libcudart.so*", + "/usr/lib/wsl/lib/libcudart.so*", + "/usr/lib/wsl/drivers/*/libcudart.so*", + "/opt/cuda/lib64/libcudart.so*", + "/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*", + "/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*", + "/usr/lib/aarch64-linux-gnu/libcudart.so*", + "/usr/local/cuda/lib*/libcudart.so*", + "/usr/lib*/libcudart.so*", + 
"/usr/local/lib*/libcudart.so*", +} + +var NvmlGlobs = []string{} + +var NvcudaGlobs = []string{ + "/usr/local/cuda*/targets/*/lib/libcuda.so*", + "/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*", + "/usr/lib/*-linux-gnu/libcuda.so*", + "/usr/lib/wsl/lib/libcuda.so*", + "/usr/lib/wsl/drivers/*/libcuda.so*", + "/opt/cuda/lib*/libcuda.so*", + "/usr/local/cuda/lib*/libcuda.so*", + "/usr/lib*/libcuda.so*", + "/usr/local/lib*/libcuda.so*", +} + +var OneapiGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*", + "/usr/lib*/libze_intel_gpu.so*", +} + +var ( + CudartMgmtName = "libcudart.so*" + NvcudaMgmtName = "libcuda.so*" + NvmlMgmtName = "" // not currently wired on linux + OneapiMgmtName = "libze_intel_gpu.so*" + VulkanMgmtName = "libvulkan.so*" + libcapMgmtName = "libcap.so*" +) + +var VulkanGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libvulkan.so*", + "/usr/lib*/libvulkan.so*", +} + +var capLinuxGlobs = []string{ + "/usr/lib/x86_64-linux-gnu/libcap.so*", + "/usr/lib*/libcap.so*", +} + +func FindLibCapLibs() []string { + return FindGPULibs(libcapMgmtName, capLinuxGlobs) +} + +func GetCPUMem() (memInfo, error) { + var mem memInfo + var total, available, free, buffers, cached, freeSwap uint64 + f, err := os.Open("/proc/meminfo") + if err != nil { + return mem, err + } + defer f.Close() + s := bufio.NewScanner(f) + for s.Scan() { + line := s.Text() + switch { + case strings.HasPrefix(line, "MemTotal:"): + _, err = fmt.Sscanf(line, "MemTotal:%d", &total) + case strings.HasPrefix(line, "MemAvailable:"): + _, err = fmt.Sscanf(line, "MemAvailable:%d", &available) + case strings.HasPrefix(line, "MemFree:"): + _, err = fmt.Sscanf(line, "MemFree:%d", &free) + case strings.HasPrefix(line, "Buffers:"): + _, err = fmt.Sscanf(line, "Buffers:%d", &buffers) + case strings.HasPrefix(line, "Cached:"): + _, err = fmt.Sscanf(line, "Cached:%d", &cached) + case strings.HasPrefix(line, "SwapFree:"): + _, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap) + default: + continue + } + if err != nil { + return mem, err + } + } + mem.TotalMemory = total * format.KibiByte + mem.FreeSwap = freeSwap * format.KibiByte + if available > 0 { + mem.FreeMemory = available * format.KibiByte + } else { + mem.FreeMemory = (free + buffers + cached) * format.KibiByte + } + return mem, nil +} + +const CpuInfoFilename = "/proc/cpuinfo" + +type linuxCpuInfo struct { + ID string `cpuinfo:"processor"` + VendorID string `cpuinfo:"vendor_id"` + ModelName string `cpuinfo:"model name"` + PhysicalID string `cpuinfo:"physical id"` + Siblings string `cpuinfo:"siblings"` + CoreID string `cpuinfo:"core id"` +} + +func GetCPUDetails() ([]CPU, error) { + file, err := os.Open(CpuInfoFilename) + if err != nil { + return nil, err + } + return linuxCPUDetails(file) +} + +func linuxCPUDetails(file io.Reader) ([]CPU, error) { + reColumns := regexp.MustCompile("\t+: ") + scanner := bufio.NewScanner(file) + cpuInfos := []linuxCpuInfo{} + cpu := &linuxCpuInfo{} + for scanner.Scan() { + line := scanner.Text() + if sl := reColumns.Split(line, 2); len(sl) > 1 { + t := reflect.TypeOf(cpu).Elem() + s := reflect.ValueOf(cpu).Elem() + for i := range t.NumField() { + field := t.Field(i) + tag := field.Tag.Get("cpuinfo") + if tag == sl[0] { + s.FieldByName(field.Name).SetString(sl[1]) + break + } + } + } else if strings.TrimSpace(line) == "" && cpu.ID != "" { + cpuInfos = append(cpuInfos, *cpu) + cpu = &linuxCpuInfo{} + } + } + if cpu.ID != "" { + cpuInfos = append(cpuInfos, *cpu) + } + + // Process the sockets/cores/threads + socketByID := 
map[string]*CPU{} + coreBySocket := map[string]map[string]struct{}{} + threadsByCoreBySocket := map[string]map[string]int{} + for _, c := range cpuInfos { + if _, found := socketByID[c.PhysicalID]; !found { + socketByID[c.PhysicalID] = &CPU{ + ID: c.PhysicalID, + VendorID: c.VendorID, + ModelName: c.ModelName, + } + coreBySocket[c.PhysicalID] = map[string]struct{}{} + threadsByCoreBySocket[c.PhysicalID] = map[string]int{} + } + if c.CoreID != "" { + coreBySocket[c.PhysicalID][c.PhysicalID+":"+c.CoreID] = struct{}{} + threadsByCoreBySocket[c.PhysicalID][c.PhysicalID+":"+c.CoreID]++ + } else { + coreBySocket[c.PhysicalID][c.PhysicalID+":"+c.ID] = struct{}{} + threadsByCoreBySocket[c.PhysicalID][c.PhysicalID+":"+c.ID]++ + } + } + + // Tally up the values from the tracking maps + for id, s := range socketByID { + s.CoreCount = len(coreBySocket[id]) + s.ThreadCount = 0 + for _, tc := range threadsByCoreBySocket[id] { + s.ThreadCount += tc + } + + // This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons? + efficiencyCoreCount := 0 + for _, threads := range threadsByCoreBySocket[id] { + if threads == 1 { + efficiencyCoreCount++ + } + } + if efficiencyCoreCount == s.CoreCount { + // 1:1 mapping means they're not actually efficiency cores, but regular cores + s.EfficiencyCoreCount = 0 + } else { + s.EfficiencyCoreCount = efficiencyCoreCount + } + } + keys := make([]string, 0, len(socketByID)) + result := make([]CPU, 0, len(socketByID)) + for k := range socketByID { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + result = append(result, *socketByID[k]) + } + return result, nil +} diff --git a/discover/gpu_linux_test.go b/discover/gpu_linux_test.go new file mode 100644 index 000000000..c4d64e389 --- /dev/null +++ b/discover/gpu_linux_test.go @@ -0,0 +1,2097 @@ +package discover + +import ( + "bytes" + "log/slog" + "testing" +) + +func TestLinuxCPUDetails(t *testing.T) { + type results struct { + cores int + efficiency int + threads int + } + type testCase struct { + input string + expCPUs []results + expThreadCount int + } + testCases := map[string]*testCase{ + "#5554 Docker Ollama container inside the LXC": { + input: `processor : 0 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 0 +cpu cores : 4 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 
spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 1 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 1 +cpu cores : 4 +apicid : 1 +initial apicid : 1 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 2 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 2 +cpu cores : 4 +apicid : 2 +initial apicid : 2 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 3 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 3 +cpu cores : 4 +apicid : 3 +initial apicid : 3 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes 
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 4 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 0 +cpu cores : 4 +apicid : 4 +initial apicid : 4 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 5 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 1 +cpu cores : 4 +apicid : 5 +initial apicid : 5 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid 
avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 6 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 2 +cpu cores : 4 +apicid : 6 +initial apicid : 6 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: + +processor : 7 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2246.624 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 3 +cpu cores : 4 +apicid : 7 +initial apicid : 7 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean flushbyasid pausefilter pfthreshold v_vmsave_vmload vgif avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm flush_l1d arch_capabilities +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4493.24 +TLB size : 1024 4K pages 
+clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: +`, + expCPUs: []results{ + { + cores: 4, + efficiency: 0, + threads: 4, + }, + { + cores: 4, + efficiency: 0, + threads: 4, + }, + }, + expThreadCount: 8, + }, + + // Single Socket, 8 cores + "#5554 LXC direct output": { + input: `processor : 0 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.910 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 0 +cpu cores : 128 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 1 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.470 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 1 +cpu cores : 128 +apicid : 2 +initial apicid : 2 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif 
x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 2 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.918 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 2 +cpu cores : 128 +apicid : 4 +initial apicid : 4 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 3 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 3 +cpu cores : 128 +apicid : 6 +initial apicid : 6 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt 
lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 4 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3090.662 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 4 +cpu cores : 128 +apicid : 8 +initial apicid : 8 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 5 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3093.734 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 5 +cpu cores : 128 +apicid : 10 +initial apicid : 10 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc 
cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 6 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 6 +cpu cores : 128 +apicid : 12 +initial apicid : 12 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 7 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 7 +cpu cores : 128 +apicid : 14 +initial apicid : 14 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f 
avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] +`, + expCPUs: []results{ + { + cores: 8, + efficiency: 0, + threads: 8, + }, + }, + expThreadCount: 8, + }, + + // Note: this was a partial cut-and-paste missing at least some initial logical processor definitions + // Single Socket, 29 cores + "#5554 LXC docker container output": { + input: `processor : 483 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 19 +cpu cores : 128 +apicid : 295 +initial apicid : 295 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 484 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 20 +cpu cores : 128 +apicid : 297 +initial apicid : 297 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid 
aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 485 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 21 +cpu cores : 128 +apicid : 299 +initial apicid : 299 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 486 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 22 +cpu cores : 128 +apicid : 301 +initial apicid : 301 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush 
mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 487 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 23 +cpu cores : 128 +apicid : 303 +initial apicid : 303 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 488 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.717 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 24 +cpu cores : 128 +apicid : 305 +initial apicid : 305 +fpu : yes 
+fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 489 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 25 +cpu cores : 128 +apicid : 307 +initial apicid : 307 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 490 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 
+cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 26 +cpu cores : 128 +apicid : 309 +initial apicid : 309 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 491 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 27 +cpu cores : 128 +apicid : 311 +initial apicid : 311 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 492 +vendor_id : AuthenticAMD 
+cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 28 +cpu cores : 128 +apicid : 313 +initial apicid : 313 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 493 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 29 +cpu cores : 128 +apicid : 315 +initial apicid : 315 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits 
physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 494 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 30 +cpu cores : 128 +apicid : 317 +initial apicid : 317 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 495 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 31 +cpu cores : 128 +apicid : 319 +initial apicid : 319 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 
spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 496 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 80 +cpu cores : 128 +apicid : 417 +initial apicid : 417 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 497 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 81 +cpu cores : 128 +apicid : 419 +initial apicid : 419 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni 
avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 498 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 82 +cpu cores : 128 +apicid : 421 +initial apicid : 421 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 499 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 83 +cpu cores : 128 +apicid : 423 +initial apicid : 423 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter 
pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 500 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 84 +cpu cores : 128 +apicid : 425 +initial apicid : 425 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 501 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 85 +cpu cores : 128 +apicid : 427 +initial apicid : 427 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf 
xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 502 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 86 +cpu cores : 128 +apicid : 429 +initial apicid : 429 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 503 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 87 +cpu cores : 128 +apicid : 431 +initial apicid : 431 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni 
avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 504 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 88 +cpu cores : 128 +apicid : 433 +initial apicid : 433 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 505 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 89 +cpu cores : 128 +apicid : 435 +initial apicid : 435 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced 
vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 506 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 90 +cpu cores : 128 +apicid : 437 +initial apicid : 437 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 507 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 91 +cpu cores : 128 +apicid : 439 +initial apicid : 439 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce 
topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 508 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 92 +cpu cores : 128 +apicid : 441 +initial apicid : 441 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 509 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 93 +cpu cores : 128 +apicid : 443 +initial apicid : 443 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe 
popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 510 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 94 +cpu cores : 128 +apicid : 445 +initial apicid : 445 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 511 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 1 +siblings : 256 +core id : 95 +cpu cores : 128 +apicid : 447 +initial apicid : 447 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good 
amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] +`, + expCPUs: []results{ + { + cores: 29, + efficiency: 0, + threads: 29, + }, + }, + expThreadCount: 29, + }, + + "#5554 LXC docker output": { + input: `processor : 0 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.910 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 0 +cpu cores : 128 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 1 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.470 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 1 +cpu cores : 128 
+apicid : 2 +initial apicid : 2 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 2 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3094.918 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 2 +cpu cores : 128 +apicid : 4 +initial apicid : 4 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 3 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode 
: 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 3 +cpu cores : 128 +apicid : 6 +initial apicid : 6 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 4 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3090.662 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 4 +cpu cores : 128 +apicid : 8 +initial apicid : 8 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 5 +vendor_id 
: AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 3093.734 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 5 +cpu cores : 128 +apicid : 10 +initial apicid : 10 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 6 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 6 +cpu cores : 128 +apicid : 12 +initial apicid : 12 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 
bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +processor : 7 +vendor_id : AuthenticAMD +cpu family : 25 +model : 160 +model name : AMD EPYC 9754 128-Core Processor +stepping : 2 +microcode : 0xaa00212 +cpu MHz : 2250.000 +cache size : 1024 KB +physical id : 0 +siblings : 256 +core id : 7 +cpu cores : 128 +apicid : 14 +initial apicid : 14 +fpu : yes +fpu_exception : yes +cpuid level : 16 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass srso +bogomips : 4492.85 +TLB size : 3584 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 52 bits physical, 57 bits virtual +power management: ts ttp tm hwpstate cpb eff_freq_ro [13] [14] + +`, + expCPUs: []results{ + { + cores: 8, + efficiency: 0, + threads: 8, + }, + }, + expThreadCount: 8, + }, + + // exposed as 8 sockets, each with 1 core, no hyperthreading + "#7359 VMware multi-core core VM": { + input: `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 0 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 1 +vendor_id 
: GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 2 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 2 +initial apicid : 2 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 2 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 4 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 4 +initial apicid : 4 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 3 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 6 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 6 +initial apicid : 6 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms 
invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 4 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 8 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 8 +initial apicid : 8 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 5 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 10 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 10 +initial apicid : 10 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 6 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 12 +siblings : 1 +core id : 0 +cpu cores : 1 
+apicid : 12 +initial apicid : 12 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: + +processor : 7 +vendor_id : GenuineIntel +cpu family : 6 +model : 106 +model name : Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz +stepping : 6 +microcode : 0xd0003d1 +cpu MHz : 2893.202 +cache size : 24576 KB +physical id : 14 +siblings : 1 +core id : 0 +cpu cores : 1 +apicid : 14 +initial apicid : 14 +fpu : yes +fpu_exception : yes +cpuid level : 27 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid fsrm md_clear flush_l1d arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass swapgs itlb_multihit mmio_stale_data eibrs_pbrsb gds bhi +bogomips : 5786.40 +clflush size : 64 +cache_alignment : 64 +address sizes : 45 bits physical, 48 bits virtual +power management: +`, + expCPUs: []results{ + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + { + cores: 1, + efficiency: 0, + threads: 1, + }, + }, + expThreadCount: 8, + }, + + // Emulated dual socket setup, 2 sockets, 2 cores each, with hyperthreading + "#7287 HyperV 2 socket exposed to VM": { + input: `processor : 0 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 0 +cpu cores : 2 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall 
nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7585.49 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 1 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 0 +cpu cores : 2 +apicid : 1 +initial apicid : 1 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7585.49 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 2 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 1 +cpu cores : 2 +apicid : 2 +initial apicid : 2 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7585.49 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 3 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 0 +siblings : 4 +core id : 1 +cpu cores : 2 +apicid : 3 +initial apicid : 3 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm 
constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7585.49 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 4 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 0 +cpu cores : 2 +apicid : 4 +initial apicid : 4 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7634.51 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 5 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 0 +cpu cores : 2 +apicid : 5 +initial apicid : 5 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7634.51 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 6 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3792.747 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 1 +cpu cores : 2 +apicid : 6 +initial apicid : 6 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable 
nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7634.51 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: + +processor : 7 +vendor_id : AuthenticAMD +cpu family : 23 +model : 96 +model name : AMD Ryzen 3 4100 4-Core Processor +stepping : 1 +microcode : 0xffffffff +cpu MHz : 3688.684 +cache size : 512 KB +physical id : 1 +siblings : 4 +core id : 1 +cpu cores : 2 +apicid : 7 +initial apicid : 7 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru arat umip rdpid +bugs : sysret_ss_attrs null_seg spectre_v1 spectre_v2 spec_store_bypass retbleed smt_rsb srso +bogomips : 7634.51 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management: +`, + expCPUs: []results{ + { + cores: 2, + efficiency: 0, + threads: 4, + }, + { + cores: 2, + efficiency: 0, + threads: 4, + }, + }, + expThreadCount: 4, + }, + } + for k, v := range testCases { + t.Run(k, func(t *testing.T) { + buf := bytes.NewBufferString(v.input) + cpus, err := linuxCPUDetails(buf) + if err != nil { + t.Fatal(err) + } + + slog.Info("example", "scenario", k, "cpus", cpus) + si := SystemInfo{ + System: CPUInfo{ + CPUs: cpus, + }, + } + threadCount := si.GetOptimalThreadCount() + if len(v.expCPUs) != len(cpus) { + t.Fatalf("incorrect number of sockets: expected:%v got:%v", v.expCPUs, cpus) + } + for i, c := range cpus { + if c.CoreCount != v.expCPUs[i].cores { + t.Fatalf("incorrect number of cores: expected:%v got:%v", v.expCPUs[i], c) + } + if c.EfficiencyCoreCount != v.expCPUs[i].efficiency { + t.Fatalf("incorrect number of efficiency cores: expected:%v got:%v", v.expCPUs[i], c) + } + if c.ThreadCount != v.expCPUs[i].threads { + t.Fatalf("incorrect number of threads: expected:%v got:%v", v.expCPUs[i], c) + } + } + + if threadCount != v.expThreadCount { + t.Fatalf("incorrect thread count expected:%d got:%d", v.expThreadCount, threadCount) + } + }) + } +} diff --git a/gpu/gpu_oneapi.go b/discover/gpu_oneapi.go similarity index 96% rename from gpu/gpu_oneapi.go rename to discover/gpu_oneapi.go index 9864bde51..77941f5b3 100644 --- a/gpu/gpu_oneapi.go +++ b/discover/gpu_oneapi.go @@ -1,6 +1,6 @@ //go:build linux || windows -package gpu +package discover import ( "log/slog" diff --git a/gpu/gpu_test.go b/discover/gpu_test.go similarity index 99% rename from gpu/gpu_test.go rename to discover/gpu_test.go index 13a3f5442..0c6ef7bad 100644 --- a/gpu/gpu_test.go +++ b/discover/gpu_test.go @@ -1,4 +1,4 @@ -package 
gpu +package discover import ( "runtime" )
diff --git a/discover/gpu_windows.go b/discover/gpu_windows.go
new file mode 100644
index 000000000..d2a4d50c1
--- /dev/null
+++ b/discover/gpu_windows.go
@@ -0,0 +1,243 @@
+package discover
+
+import (
+	"fmt"
+	"log/slog"
+	"syscall"
+	"unsafe"
+)
+
+type MEMORYSTATUSEX struct {
+	length               uint32
+	MemoryLoad           uint32
+	TotalPhys            uint64
+	AvailPhys            uint64
+	TotalPageFile        uint64
+	AvailPageFile        uint64
+	TotalVirtual         uint64
+	AvailVirtual         uint64
+	AvailExtendedVirtual uint64
+}
+
+var (
+	k32                              = syscall.NewLazyDLL("kernel32.dll")
+	globalMemoryStatusExProc         = k32.NewProc("GlobalMemoryStatusEx")
+	sizeofMemoryStatusEx             = uint32(unsafe.Sizeof(MEMORYSTATUSEX{}))
+	GetLogicalProcessorInformationEx = k32.NewProc("GetLogicalProcessorInformationEx")
+)
+
+var CudartGlobs = []string{
+	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
+}
+
+var NvmlGlobs = []string{
+	"c:\\Windows\\System32\\nvml.dll",
+}
+
+var NvcudaGlobs = []string{
+	"c:\\windows\\system*\\nvcuda.dll",
+}
+
+var OneapiGlobs = []string{
+	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
+}
+
+var VulkanGlobs = []string{
+	"c:\\Windows\\System32\\vulkan-1.dll",
+}
+
+var (
+	CudartMgmtName = "cudart64_*.dll"
+	NvcudaMgmtName = "nvcuda.dll"
+	NvmlMgmtName   = "nvml.dll"
+	OneapiMgmtName = "ze_intel_gpu64.dll"
+	VulkanMgmtName = "vulkan-1.dll"
+)
+
+func FindLibCapLibs() []string {
+	return []string{""}
+}
+
+func GetCPUMem() (memInfo, error) {
+	memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
+	r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus)))
+	if r1 == 0 {
+		return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
+	}
+	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
+}
+
+type LOGICAL_PROCESSOR_RELATIONSHIP uint32
+
+const (
+	RelationProcessorCore LOGICAL_PROCESSOR_RELATIONSHIP = iota
+	RelationNumaNode
+	RelationCache
+	RelationProcessorPackage
+	RelationGroup
+	RelationProcessorDie
+	RelationNumaNodeEx
+	RelationProcessorModule
+)
+const RelationAll LOGICAL_PROCESSOR_RELATIONSHIP = 0xffff
+
+type GROUP_AFFINITY struct {
+	Mask     uintptr // KAFFINITY
+	Group    uint16
+	Reserved [3]uint16
+}
+
+type PROCESSOR_RELATIONSHIP struct {
+	Flags           byte
+	EfficiencyClass byte
+	Reserved        [20]byte
+	GroupCount      uint16
+	GroupMask       [1]GROUP_AFFINITY // len GroupCount
+}
+
+// Omitted unused structs: NUMA_NODE_RELATIONSHIP CACHE_RELATIONSHIP GROUP_RELATIONSHIP
+
+type SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX struct {
+	Relationship LOGICAL_PROCESSOR_RELATIONSHIP
+	Size         uint32
+	U            [1]byte // Union len Size
+	// PROCESSOR_RELATIONSHIP
+	// NUMA_NODE_RELATIONSHIP
+	// CACHE_RELATIONSHIP
+	// GROUP_RELATIONSHIP
+}
+
+func (group *GROUP_AFFINITY) IsMember(target *GROUP_AFFINITY) bool {
+	if group == nil || target == nil {
+		return false
+	}
+	return group.Mask&target.Mask != 0
+}
+
+type winPackage struct {
+	groups              []*GROUP_AFFINITY
+	coreCount           int // performance cores = coreCount - efficiencyCoreCount
+	efficiencyCoreCount int
+	threadCount         int
+}
+
+func (pkg *winPackage) IsMember(target *GROUP_AFFINITY) bool {
+	for _, group := range pkg.groups {
+		if group.IsMember(target) {
+			return true
+		}
+	}
+	return false
+}
+
+func getLogicalProcessorInformationEx() ([]byte, error) {
+	buf := make([]byte, 1)
+	bufSize := len(buf)
+	ret, _, err := GetLogicalProcessorInformationEx.Call(
+		uintptr(RelationAll),
+		uintptr(unsafe.Pointer(&buf[0])),
+		uintptr(unsafe.Pointer(&bufSize)),
+	)
+	if ret != 0 {
+		return nil, fmt.Errorf("failed to determine size info ret:%d %w", ret, err)
+	}
+
+	buf = make([]byte, bufSize)
+	ret, _, err = GetLogicalProcessorInformationEx.Call(
+		uintptr(RelationAll),
+		uintptr(unsafe.Pointer(&buf[0])),
+		uintptr(unsafe.Pointer(&bufSize)),
+	)
+	if ret == 0 {
+		return nil, fmt.Errorf("failed to gather processor information ret:%d buflen:%d %w", ret, bufSize, err)
+	}
+	return buf, nil
+}
+
+func processSystemLogicalProcessorInforationList(buf []byte) []*winPackage {
+	var slpi *SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX
+	// Find all the packages first
+	packages := []*winPackage{}
+	for bufOffset := 0; bufOffset < len(buf); bufOffset += int(slpi.Size) {
+		slpi = (*SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)(unsafe.Pointer(&buf[bufOffset]))
+		if slpi.Relationship != RelationProcessorPackage {
+			continue
+		}
+		pr := (*PROCESSOR_RELATIONSHIP)(unsafe.Pointer(&slpi.U[0]))
+		pkg := &winPackage{}
+		ga0 := unsafe.Pointer(&pr.GroupMask[0])
+		for j := range pr.GroupCount {
+			gm := (*GROUP_AFFINITY)(unsafe.Pointer(uintptr(ga0) + uintptr(j)*unsafe.Sizeof(GROUP_AFFINITY{})))
+			pkg.groups = append(pkg.groups, gm)
+		}
+		packages = append(packages, pkg)
+	}
+
+	slog.Info("packages", "count", len(packages))
+
+	// To identify efficiency cores we have to compare the relative values
+	// Larger values are "less efficient" (aka, more performant)
+	var maxEfficiencyClass byte
+	for bufOffset := 0; bufOffset < len(buf); bufOffset += int(slpi.Size) {
+		slpi = (*SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)(unsafe.Pointer(&buf[bufOffset]))
+		if slpi.Relationship != RelationProcessorCore {
+			continue
+		}
+		pr := (*PROCESSOR_RELATIONSHIP)(unsafe.Pointer(&slpi.U[0]))
+		if pr.EfficiencyClass > maxEfficiencyClass {
+			maxEfficiencyClass = pr.EfficiencyClass
+		}
+	}
+	if maxEfficiencyClass > 0 {
+		slog.Info("efficiency cores detected", "maxEfficiencyClass", maxEfficiencyClass)
+	}
+
+	// then match up the Cores to the Packages, count up cores, threads and efficiency cores
+	for bufOffset := 0; bufOffset < len(buf); bufOffset += int(slpi.Size) {
+		slpi = (*SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)(unsafe.Pointer(&buf[bufOffset]))
+		if slpi.Relationship != RelationProcessorCore {
+			continue
+		}
+		pr := (*PROCESSOR_RELATIONSHIP)(unsafe.Pointer(&slpi.U[0]))
+		ga0 := unsafe.Pointer(&pr.GroupMask[0])
+		for j := range pr.GroupCount {
+			gm := (*GROUP_AFFINITY)(unsafe.Pointer(uintptr(ga0) + uintptr(j)*unsafe.Sizeof(GROUP_AFFINITY{})))
+			for _, pkg := range packages {
+				if pkg.IsMember(gm) {
+					pkg.coreCount++
+					if pr.Flags == 0 {
+						pkg.threadCount++
+					} else {
+						pkg.threadCount += 2
+					}
+					if pr.EfficiencyClass < maxEfficiencyClass {
+						pkg.efficiencyCoreCount++
+					}
+				}
+			}
+		}
+	}
+
+	// Summarize the results
+	for i, pkg := range packages {
+		slog.Info("", "package", i, "cores", pkg.coreCount, "efficiency", pkg.efficiencyCoreCount, "threads", pkg.threadCount)
+	}
+
+	return packages
+}
+
+func GetCPUDetails() ([]CPU, error) {
+	buf, err := getLogicalProcessorInformationEx()
+	if err != nil {
+		return nil, err
+	}
+	packages := processSystemLogicalProcessorInforationList(buf)
+	cpus := make([]CPU, len(packages))
+
+	for i, pkg := range packages {
+		cpus[i].CoreCount = pkg.coreCount
+		cpus[i].EfficiencyCoreCount = pkg.efficiencyCoreCount
+		cpus[i].ThreadCount = pkg.threadCount
+	}
+	return cpus, nil
+}
diff --git a/discover/gpu_windows_test.go b/discover/gpu_windows_test.go
new file mode 100644
index 000000000..c4daa7b72
--- 
/dev/null +++ b/discover/gpu_windows_test.go @@ -0,0 +1,77 @@ +package discover + +import "testing" + +func TestProcessSystemLogicalProcessorInforationList(t *testing.T) { + type pkgs struct { + cores int + efficiency int + threads int + } + type testCase struct { + input []byte + expected []pkgs + } + testCases := map[string]*testCase{ + "AMD64 Family 25 Model 97 Stepping 2 ": { + input: []byte{ + 0x3, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x3, 0x10, 0x40, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 
0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 
0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x3, 0x10, 0x40, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 
0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x8, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x50, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x20, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, + }, + expected: []pkgs{ + { + cores: 16, + efficiency: 0, + threads: 32, + }, + }, + }, + "Intel64 Family 6 Model 183 Stepping 1": { + input: []byte{ + 0x3, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x3, 0xc, 0x40, 0x0, 0x0, 0x0, 0xe0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0xc, 0x40, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x20, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x50, 0x0, 0x0, 0x0, 0x1, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x18, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0,
+ },
+ expected: []pkgs{
+ {
+ cores: 16,
+ efficiency: 8,
+ threads: 24,
+ },
+ },
+ },
+ "dual Intel64 Family 6 Model 85 Stepping 4": {
+ input: []byte{
+ 0x3, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x3, 0xb, 0x40, 0x0, 0x0, 0x0, 0xb8, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x3, 0xb, 0x40, 0x0, 0x0, 0x0, 0xb8, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 
0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 
0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x7, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x1, 0x8, 0x40, 0x0, 0x0, 0x80, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x38, 0x0, 0x0, 0x0, 0x2, 0x10, 0x40, 0x0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x30, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x0, 0x0, 0x80, 0x0, 0x0, 0x0, 0x2, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x28, 0x28, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, 0x28, 0x28, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, 0x0, 0x0, + }, + expected: []pkgs{ + { + cores: 40, + efficiency: 0, + threads: 80, + }, { + cores: 40, + efficiency: 0, + threads: 80, + }, + }, + }, + } + + for k, v := range testCases { + t.Run(k, func(t *testing.T) { + resp := processSystemLogicalProcessorInforationList(v.input) + if len(resp) != len(v.expected) { + t.Fatalf("incorrect number of packages %d, got %d", v.expected, len(resp)) + } + for i, pkg := range v.expected { + if resp[i].coreCount != pkg.cores { + t.Fatalf("[%d] expected core count %d, got %d", i, pkg.cores, resp[i].coreCount) + } + if resp[i].efficiencyCoreCount != pkg.efficiency { + t.Fatalf("[%d] expected efficiency count %d, got %d", i, pkg.efficiency, resp[i].efficiencyCoreCount) + } + if resp[i].threadCount != pkg.threads { + t.Fatalf("[%d] expected thread count %d, got %d", i, pkg.threads, resp[i].threadCount) + } + } + }) + } +} diff --git a/gpu/types.go b/discover/types.go similarity index 66% rename from gpu/types.go rename to discover/types.go index e623c3766..417f8ce7f 100644 --- a/gpu/types.go +++ b/discover/types.go @@ -1,20 +1,21 @@ -package gpu +package discover import ( "fmt" "log/slog" "github.com/ollama/ollama/format" + "github.com/ollama/ollama/runners" ) type memInfo struct { TotalMemory uint64 `json:"total_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"` - FreeSwap uint64 `json:"free_swap,omitempty"` + FreeSwap uint64 `json:"free_swap,omitempty"` // TODO split this out for system only } // Beginning of an `ollama info` command -type GpuInfo struct { +type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? memInfo Library string `json:"library,omitempty"` @@ -25,7 +26,7 @@ type GpuInfo struct { MinimumMemory uint64 `json:"-"` // Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly - DependencyPath string `json:"lib_path,omitempty"` + DependencyPath []string `json:"lib_path,omitempty"` // Extra environment variables specific to the GPU as list of [key,value] EnvWorkarounds [][2]string `json:"envs,omitempty"` @@ -47,8 +48,26 @@ type GpuInfo struct { // TODO other performance capability info to help in scheduling decisions } +func (gpu GpuInfo) RunnerName() string { + if gpu.Variant != "" { + return gpu.Library + "_" + gpu.Variant + } + return gpu.Library +} + type CPUInfo struct { GpuInfo + CPUs []CPU +} + +// CPU type represents a CPU Package occupying a socket +type CPU struct { + ID string `cpuinfo:"processor"` + VendorID string `cpuinfo:"vendor_id"` + ModelName string `cpuinfo:"model name"` + CoreCount int + EfficiencyCoreCount int // Performance = CoreCount - Efficiency + ThreadCount int } type CudaGPUInfo struct { @@ -83,6 +102,11 @@ type VulkanGPUInfoList []VulkanGPUInfo type GpuInfoList []GpuInfo +type UnsupportedGPUInfo struct { + GpuInfo + Reason string `json:"reason"` +} + // Split up the set of gpu info's by Library and variant func (l GpuInfoList) ByLibrary() []GpuInfoList { resp := []GpuInfoList{} @@ -90,7 +114,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList { for _, info := range l { found := false requested := info.Library - if info.Variant != CPUCapabilityNone.String() { + if info.Variant != runners.CPUCapabilityNone.String() { requested += "_" + info.Variant } for i, lib := range libs { @@ -131,25 +155,37 @@ func (a ByFreeMemory) Len() int { return len(a) } func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a ByFreeMemory) Less(i, j 
int) bool { return a[i].FreeMemory < a[j].FreeMemory } -type CPUCapability uint32 - -// Override at build time when building base GPU runners -var GPURunnerCPUCapability = CPUCapabilityAVX - -const ( - CPUCapabilityNone CPUCapability = iota - CPUCapabilityAVX - CPUCapabilityAVX2 - // TODO AVX512 -) - -func (c CPUCapability) String() string { - switch c { - case CPUCapabilityAVX: - return "avx" - case CPUCapabilityAVX2: - return "avx2" - default: - return "no vector extensions" - } +type SystemInfo struct { + System CPUInfo `json:"system"` + GPUs []GpuInfo `json:"gpus"` + UnsupportedGPUs []UnsupportedGPUInfo `json:"unsupported_gpus"` + DiscoveryErrors []string `json:"discovery_errors"` +} + +// Return the optimal number of threads to use for inference +func (si SystemInfo) GetOptimalThreadCount() int { + if len(si.System.CPUs) == 0 { + return 0 + } + + coreCount := 0 + for _, c := range si.System.CPUs { + coreCount += c.CoreCount - c.EfficiencyCoreCount + } + + return coreCount +} + +// For each GPU, check if it does NOT support flash attention +func (l GpuInfoList) FlashAttentionSupported() bool { + for _, gpu := range l { + supportsFA := gpu.Library == "metal" || + (gpu.Library == "cuda" && gpu.DriverMajor >= 7) || + gpu.Library == "rocm" + + if !supportsFA { + return false + } + } + return true } diff --git a/gpu/vulkan_common.go b/discover/vulkan_common.go similarity index 96% rename from gpu/vulkan_common.go rename to discover/vulkan_common.go index 8d3d15d06..4dccbade9 100644 --- a/gpu/vulkan_common.go +++ b/discover/vulkan_common.go @@ -1,4 +1,4 @@ -package gpu +package discover import ( "log/slog" diff --git a/docs/api.md b/docs/api.md index 95e79e007..50402ecc7 100644 --- a/docs/api.md +++ b/docs/api.md @@ -13,6 +13,7 @@ - [Push a Model](#push-a-model) - [Generate Embeddings](#generate-embeddings) - [List Running Models](#list-running-models) +- [Version](#version) ## Conventions @@ -45,14 +46,18 @@ Generate a response for a given prompt with a provided model. This is a streamin Advanced parameters (optional): -- `format`: the format to return a response in. Currently the only accepted value is `json` +- `format`: the format to return a response in. Format can be `json` or a JSON schema - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `system`: system message to (overrides what is defined in the `Modelfile`) - `template`: the prompt template to use (overrides what is defined in the `Modelfile`) -- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) +- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory + +#### Structured outputs + +Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [structured outputs](#request-structured-outputs) example below. 
#### JSON mode @@ -69,7 +74,7 @@ Enable JSON mode by setting the `format` parameter to `json`. This will structur ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt": "Why is the sky blue?" }' ``` @@ -80,7 +85,7 @@ A stream of JSON objects is returned: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T08:52:19.385406455-07:00", "response": "The", "done": false @@ -102,7 +107,7 @@ To calculate how fast the response is generated in tokens per second (token/s), ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T19:22:45.499127Z", "response": "", "done": true, @@ -124,7 +129,7 @@ A response can be received in one reply when streaming is off. ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt": "Why is the sky blue?", "stream": false }' @@ -136,7 +141,7 @@ If `stream` is set to `false`, the response will be a single JSON object: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T19:22:45.499127Z", "response": "The sky is blue because it is the color of the sky.", "done": true, @@ -185,6 +190,52 @@ curl http://localhost:11434/api/generate -d '{ } ``` +#### Request (Structured outputs) + +##### Request + +```shell +curl -X POST http://localhost:11434/api/generate -H "Content-Type: application/json" -d '{ + "model": "llama3.1:8b", + "prompt": "Ollama is 22 years old and is busy saving the world. Respond using JSON", + "stream": false, + "format": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "available": { + "type": "boolean" + } + }, + "required": [ + "age", + "available" + ] + } +}' +``` + +##### Response + +```json +{ + "model": "llama3.1:8b", + "created_at": "2024-12-06T00:48:09.983619Z", + "response": "{\n \"age\": 22,\n \"available\": true\n}", + "done": true, + "done_reason": "stop", + "context": [1, 2, 3], + "total_duration": 1075509083, + "load_duration": 567678166, + "prompt_eval_count": 28, + "prompt_eval_duration": 236000000, + "eval_count": 16, + "eval_duration": 269000000 +} +``` + #### Request (JSON mode) > [!IMPORTANT] @@ -194,7 +245,7 @@ curl http://localhost:11434/api/generate -d '{ ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt": "What color is the sky at different times of the day? 
Respond using JSON", "format": "json", "stream": false @@ -205,7 +256,7 @@ curl http://localhost:11434/api/generate -d '{ ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-11-09T21:07:55.186497Z", "response": "{\n\"morning\": {\n\"color\": \"blue\"\n},\n\"noon\": {\n\"color\": \"blue-gray\"\n},\n\"afternoon\": {\n\"color\": \"warm gray\"\n},\n\"evening\": {\n\"color\": \"orange\"\n}\n}\n", "done": true, @@ -327,7 +378,7 @@ If you want to set custom options for the model at runtime rather than in the Mo ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt": "Why is the sky blue?", "stream": false, "options": { @@ -337,7 +388,6 @@ curl http://localhost:11434/api/generate -d '{ "top_k": 20, "top_p": 0.9, "min_p": 0.0, - "tfs_z": 0.5, "typical_p": 0.7, "repeat_last_n": 33, "temperature": 0.8, @@ -355,7 +405,6 @@ curl http://localhost:11434/api/generate -d '{ "num_gpu": 1, "main_gpu": 0, "low_vram": false, - "f16_kv": true, "vocab_only": false, "use_mmap": true, "use_mlock": false, @@ -368,7 +417,7 @@ curl http://localhost:11434/api/generate -d '{ ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T19:22:45.499127Z", "response": "The sky is blue because it is the color of the sky.", "done": true, @@ -390,7 +439,7 @@ If an empty prompt is provided, the model will be loaded into memory. ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1" + "model": "llama3.2" }' ``` @@ -400,7 +449,7 @@ A single JSON object is returned: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-12-18T19:52:07.071755Z", "response": "", "done": true @@ -415,7 +464,7 @@ If an empty prompt is provided and the `keep_alive` parameter is set to `0`, a m ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "keep_alive": 0 }' ``` @@ -426,7 +475,7 @@ A single JSON object is returned: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2024-09-12T03:54:03.516566Z", "response": "", "done": true, @@ -457,11 +506,15 @@ The `message` object has the following fields: Advanced parameters (optional): -- `format`: the format to return a response in. Currently the only accepted value is `json` +- `format`: the format to return a response in. Format can be `json` or a JSON schema. - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) +### Structured outputs + +Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below. + ### Examples #### Chat Request (Streaming) @@ -472,7 +525,7 @@ Send a chat message with a streaming response. 
```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "user", @@ -488,7 +541,7 @@ A stream of JSON objects is returned: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T08:52:19.385406455-07:00", "message": { "role": "assistant", @@ -503,7 +556,7 @@ Final response: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T19:22:45.499127Z", "done": true, "total_duration": 4883583458, @@ -521,7 +574,7 @@ Final response: ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "user", @@ -536,7 +589,7 @@ curl http://localhost:11434/api/chat -d '{ ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-12-12T14:13:43.416799Z", "message": { "role": "assistant", @@ -552,6 +605,54 @@ curl http://localhost:11434/api/chat -d '{ } ``` +#### Chat request (Structured outputs) + +##### Request + +```shell +curl -X POST http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{ + "model": "llama3.1", + "messages": [{"role": "user", "content": "Ollama is 22 years old and busy saving the world. Return a JSON object with the age and availability."}], + "stream": false, + "format": { + "type": "object", + "properties": { + "age": { + "type": "integer" + }, + "available": { + "type": "boolean" + } + }, + "required": [ + "age", + "available" + ] + }, + "options": { + "temperature": 0 + } +}' +``` + +##### Response + +```json +{ + "model": "llama3.1", + "created_at": "2024-12-06T00:46:58.265747Z", + "message": { "role": "assistant", "content": "{\"age\": 22, \"available\": false}" }, + "done_reason": "stop", + "done": true, + "total_duration": 2254970291, + "load_duration": 574751416, + "prompt_eval_count": 34, + "prompt_eval_duration": 1502000000, + "eval_count": 12, + "eval_duration": 175000000 +} +``` + #### Chat request (With History) Send a chat message with a conversation history. You can use this same approach to start the conversation using multi-shot or chain-of-thought prompting. @@ -560,7 +661,7 @@ Send a chat message with a conversation history. 
You can use this same approach ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "user", @@ -584,7 +685,7 @@ A stream of JSON objects is returned: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T08:52:19.385406455-07:00", "message": { "role": "assistant", @@ -598,7 +699,7 @@ Final response: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-08-04T19:22:45.499127Z", "done": true, "total_duration": 8113331500, @@ -656,7 +757,7 @@ curl http://localhost:11434/api/chat -d '{ ```shell curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "user", @@ -674,7 +775,7 @@ curl http://localhost:11434/api/chat -d '{ ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2023-12-12T14:13:43.416799Z", "message": { "role": "assistant", @@ -696,7 +797,7 @@ curl http://localhost:11434/api/chat -d '{ ``` curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "user", @@ -735,7 +836,7 @@ curl http://localhost:11434/api/chat -d '{ ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at": "2024-07-22T20:33:28.123648Z", "message": { "role": "assistant", @@ -771,7 +872,7 @@ If the messages array is empty, the model will be loaded into memory. ``` curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [] }' ``` @@ -779,7 +880,7 @@ curl http://localhost:11434/api/chat -d '{ ##### Response ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at":"2024-09-12T21:17:29.110811Z", "message": { "role": "assistant", @@ -798,7 +899,7 @@ If the messages array is empty and the `keep_alive` parameter is set to `0`, a m ``` curl http://localhost:11434/api/chat -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [], "keep_alive": 0 }' @@ -810,7 +911,7 @@ A single JSON object is returned: ```json { - "model": "llama3.1", + "model": "llama3.2", "created_at":"2024-09-12T21:33:17.547535Z", "message": { "role": "assistant", @@ -831,10 +932,30 @@ Create a model from a [`Modelfile`](./modelfile.md). It is recommended to set `m ### Parameters -- `name`: name of the model to create +- `model`: name of the model to create - `modelfile` (optional): contents of the Modelfile - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects - `path` (optional): path to the Modelfile +- `quantize` (optional): quantize a non-quantized (e.g. float16) model + +#### Quantization types + +| Type | Recommended | +| --- | :-: | +| q2_K | | +| q3_K_L | | +| q3_K_M | | +| q3_K_S | | +| q4_0 | | +| q4_1 | | +| q4_K_M | * | +| q4_K_S | | +| q5_0 | | +| q5_1 | | +| q5_K_M | | +| q5_K_S | | +| q6_K | | +| q8_0 | * | ### Examples @@ -846,14 +967,14 @@ Create a new model from a `Modelfile`. ```shell curl http://localhost:11434/api/create -d '{ - "name": "mario", + "model": "mario", "modelfile": "FROM llama3\nSYSTEM You are mario from Super Mario Bros." }' ``` ##### Response -A stream of JSON objects. Notice that the final JSON object shows a `"status": "success"`. +A stream of JSON objects is returned: ```json {"status":"reading model metadata"} @@ -869,13 +990,43 @@ A stream of JSON objects. Notice that the final JSON object shows a `"status": " {"status":"success"} ``` +#### Quantize a model + +Quantize a non-quantized model. 
+ +##### Request + +```shell +curl http://localhost:11434/api/create -d '{ + "model": "llama3.1:quantized", + "modelfile": "FROM llama3.1:8b-instruct-fp16", + "quantize": "q4_K_M" +}' +``` + +##### Response + +A stream of JSON objects is returned: + +``` +{"status":"quantizing F16 model to Q4_K_M"} +{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"} +{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"} +{"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"} +{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"} +{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"} +{"status":"writing manifest"} +{"status":"success"} +``` + + ### Check if a Blob Exists ```shell HEAD /api/blobs/:digest ``` -Ensures that the file blob used for a FROM or ADAPTER field exists on the server. This is checking your Ollama server and not Ollama.ai. +Ensures that the file blob used for a FROM or ADAPTER field exists on the server. This is checking your Ollama server and not ollama.com. #### Query Parameters @@ -980,7 +1131,7 @@ Show information about a model including details, modelfile, template, parameter ### Parameters -- `name`: name of the model to show +- `model`: name of the model to show - `verbose`: (optional) if set to `true`, returns full data for verbose response fields ### Examples @@ -989,7 +1140,7 @@ Show information about a model including details, modelfile, template, parameter ```shell curl http://localhost:11434/api/show -d '{ - "name": "llama3.1" + "model": "llama3.2" }' ``` @@ -1050,7 +1201,7 @@ Copy a model. Creates a model with another name from an existing model. ```shell curl http://localhost:11434/api/copy -d '{ - "source": "llama3.1", + "source": "llama3.2", "destination": "llama3-backup" }' ``` @@ -1069,7 +1220,7 @@ Delete a model and its data. ### Parameters -- `name`: model name to delete +- `model`: model name to delete ### Examples @@ -1077,7 +1228,7 @@ Delete a model and its data. ```shell curl -X DELETE http://localhost:11434/api/delete -d '{ - "name": "llama3:13b" + "model": "llama3:13b" }' ``` @@ -1095,7 +1246,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where ### Parameters -- `name`: name of the model to pull +- `model`: name of the model to pull - `insecure`: (optional) allow insecure connections to the library. Only use this if you are pulling from your own library during development. - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects @@ -1105,7 +1256,7 @@ Download a model from the ollama library. Cancelled pulls are resumed from where ```shell curl http://localhost:11434/api/pull -d '{ - "name": "llama3.1" + "model": "llama3.2" }' ``` @@ -1167,7 +1318,7 @@ Upload a model to a model library. Requires registering for ollama.ai and adding ### Parameters -- `name`: name of the model to push in the form of `/:` +- `model`: name of the model to push in the form of `/:` - `insecure`: (optional) allow insecure connections to the library. Only use this if you are pushing to your library during development. - `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects @@ -1177,7 +1328,7 @@ Upload a model to a model library. 
Requires registering for ollama.ai and adding ```shell curl http://localhost:11434/api/push -d '{ - "name": "mattw/pygmalion:latest" + "model": "mattw/pygmalion:latest" }' ``` @@ -1376,3 +1527,29 @@ curl http://localhost:11434/api/embeddings -d '{ ] } ``` + +## Version + +```shell +GET /api/version +``` + +Retrieve the Ollama version + +### Examples + +#### Request + +```shell +curl http://localhost:11434/api/version +``` + +#### Response + +```json +{ + "version": "0.5.1" +} +``` + + diff --git a/docs/development.md b/docs/development.md index e67689abc..e194dca0b 100644 --- a/docs/development.md +++ b/docs/development.md @@ -2,36 +2,23 @@ Install required tools: -- cmake version 3.24 or higher - go version 1.22 or higher -- gcc version 11.4.0 or higher +- OS specific C/C++ compiler (see below) +- GNU Make + + +## Overview + +Ollama uses a mix of Go and C/C++ code to interface with GPUs. The C/C++ code is compiled with both CGO and GPU library specific compilers. A set of GNU Makefiles are used to compile the project. GPU Libraries are auto-detected based on the typical environment variables used by the respective libraries, but can be overridden if necessary. The default make target will build the runners and primary Go Ollama application that will run within the repo directory. Throughout the examples below `-j 5` is suggested for 5 parallel jobs to speed up the build. You can adjust the job count based on your CPU Core count to reduce build times. If you want to relocate the built binaries, use the `dist` target and recursively copy the files in `./dist/$OS-$ARCH/` to your desired location. To learn more about the other make targets use `make help` + +Once you have built the GPU/CPU runners, you can compile the main application with `go build .` ### MacOS -```bash -brew install go cmake gcc -``` - -Optionally enable debugging and more verbose logging: +[Download Go](https://go.dev/dl/) ```bash -# At build time -export CGO_CFLAGS="-g" - -# At runtime -export OLLAMA_DEBUG=1 -``` - -Get the required libraries and build the native LLM code: - -```bash -go generate ./... -``` - -Then build ollama: - -```bash -go build . +make -j 5 ``` Now you can run `ollama`: @@ -40,114 +27,96 @@ Now you can run `ollama`: ./ollama ``` +#### Xcode 15 warnings + +If you are using Xcode newer than version 14, you may see a warning during `go build` about `ld: warning: ignoring duplicate libraries: '-lobjc'` due to Golang issue https://github.com/golang/go/issues/67799 which can be safely ignored. You can suppress the warning with `export CGO_LDFLAGS="-Wl,-no_warn_duplicate_libraries"` + ### Linux #### Linux CUDA (NVIDIA) _Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_ -Install `cmake` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) +Install `make`, `gcc` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) development and runtime packages. -Typically the build scripts will auto-detect CUDA, however, if your Linux distro -or installation approach uses unusual paths, you can specify the location by -specifying an environment variable `CUDA_LIB_DIR` to the location of the shared -libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize -a set of target CUDA architectures by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. 
"50;60;70") - -Then generate dependencies: +Typically the makefile will auto-detect CUDA, however, if your Linux distro +or installation approach uses alternative paths, you can specify the location by +overriding `CUDA_PATH` to the location of the CUDA toolkit. You can customize +a set of target CUDA architectures by setting `CUDA_ARCHITECTURES` (e.g. `CUDA_ARCHITECTURES=50;60;70`) ``` -go generate ./... +make -j 5 ``` -Then build the binary: +If both v11 and v12 tookkits are detected, runners for both major versions will be built by default. You can build just v12 with `make cuda_v12` -``` -go build . -``` +#### Older Linux CUDA (NVIDIA) + +To support older GPUs with Compute Capability 3.5 or 3.7, you will need to use an older version of the Driver from [Unix Driver Archive](https://www.nvidia.com/en-us/drivers/unix/) (tested with 470) and [CUDA Toolkit Archive](https://developer.nvidia.com/cuda-toolkit-archive) (tested with cuda V11). When you build Ollama, you will need to set two make variable to adjust the minimum compute capability Ollama supports via `make -j 5 CUDA_ARCHITECTURES="35;37;50;52" EXTRA_GOLDFLAGS="\"-X=github.com/ollama/ollama/discover.CudaComputeMajorMin=3\" \"-X=github.com/ollama/ollama/discover.CudaComputeMinorMin=5\""`. To find the Compute Capability of your older GPU, refer to [GPU Compute Capability](https://developer.nvidia.com/cuda-gpus). #### Linux ROCm (AMD) -_Your operating system distribution may already have packages for AMD ROCm and CLBlast. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_ +_Your operating system distribution may already have packages for AMD ROCm. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!_ -Install [CLBlast](https://github.com/CNugteren/CLBlast/blob/master/doc/installation.md) and [ROCm](https://rocm.docs.amd.com/en/latest/) development packages first, as well as `cmake` and `golang`. +Install [ROCm](https://rocm.docs.amd.com/en/latest/) development packages first, as well as `make`, `gcc`, and `golang`. Typically the build scripts will auto-detect ROCm, however, if your Linux distro or installation approach uses unusual paths, you can specify the location by -specifying an environment variable `ROCM_PATH` to the location of the ROCm -install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the -CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize -the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`) +specifying an environment variable `HIP_PATH` to the location of the ROCm +install (typically `/opt/rocm`). You can also customize +the AMD GPU targets by setting HIP_ARCHS (e.g. `HIP_ARCHS=gfx1101;gfx1102`) ``` -go generate ./... -``` - -Then build the binary: - -``` -go build . +make -j 5 ``` ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root. -#### Advanced CPU Settings - -By default, running `go generate ./...` will compile a few different variations -of the LLM library based on common CPU families and vector math capabilities, -including a lowest-common-denominator which should run on almost any 64 bit CPU -somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to -load. 
If you would like to build a CPU-based build customized for your -processor, you can set `OLLAMA_CUSTOM_CPU_DEFS` to the llama.cpp flags you would -like to use. For example, to compile an optimized binary for an Intel i9-9880H, -you might use: - -``` -OLLAMA_CUSTOM_CPU_DEFS="-DGGML_AVX=on -DGGML_AVX2=on -DGGML_F16C=on -DGGML_FMA=on" go generate ./... -go build . -``` - #### Containerized Linux Build -If you have Docker available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting binary is placed in `./dist` +If you have Docker and buildx available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting artifacts are placed in `./dist` and by default the script builds both arm64 and amd64 binaries. If you want to build only amd64, you can build with `PLATFORM=linux/amd64 ./scripts/build_linux.sh` ### Windows -Note: The Windows build for Ollama is still under development. +The following tools are required as a minimal development environment to build CPU inference support. -First, install required tools: - -- MSVC toolchain - C/C++ and cmake as minimal requirements - Go version 1.22 or higher -- MinGW (pick one variant) with GCC. - - [MinGW-w64](https://www.mingw-w64.org/) + - https://go.dev/dl/ +- Git + - https://git-scm.com/download/win +- clang with gcc compat and Make. There are multiple options on how to go about installing these tools on Windows. We have verified the following, but others may work as well: - [MSYS2](https://www.msys2.org/) -- The `ThreadJob` Powershell module: `Install-Module -Name ThreadJob -Scope CurrentUser` + - After installing, from an MSYS2 terminal, run `pacman -S mingw-w64-clang-x86_64-gcc-compat mingw-w64-clang-x86_64-clang make` to install the required tools + - Assuming you used the default install prefix for msys2 above, add `C:\msys64\clang64\bin` and `c:\msys64\usr\bin` to your environment variable `PATH` where you will perform the build steps below (e.g. system-wide, account-level, powershell, cmd, etc.) -Then, build the `ollama` binary: +> [!NOTE] +> Due to bugs in the GCC C++ library for unicode support, Ollama should be built with clang on windows. -```powershell -$env:CGO_ENABLED="1" -go generate ./... -go build . ``` +make -j 5 +``` + +#### GPU Support + +The GPU tools require the Microsoft native build tools. To build either CUDA or ROCm, you must first install MSVC via Visual Studio: + +- Make sure to select `Desktop development with C++` as a Workload during the Visual Studio install +- You must complete the Visual Studio install and run it once **BEFORE** installing CUDA or ROCm for the tools to properly register +- Add the location of the **64 bit (x64)** compiler (`cl.exe`) to your `PATH` +- Note: the default Developer Shell may configure the 32 bit (x86) compiler which will lead to build failures. Ollama requires a 64 bit toolchain. #### Windows CUDA (NVIDIA) -In addition to the common Windows development tools described above, install CUDA after installing MSVC. +In addition to the common Windows development tools and MSVC described above: - [NVIDIA CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) - #### Windows ROCm (AMD Radeon) -In addition to the common Windows development tools described above, install AMDs HIP package after installing MSVC. 
+In addition to the common Windows development tools and MSVC described above: - [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html) -- [Strawberry Perl](https://strawberryperl.com/) - -Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`). #### Windows arm64 @@ -166,4 +135,31 @@ Follow the instructions at https://www.msys2.org/wiki/arm64/ to set up an arm64 pacman -S mingw-w64-clang-aarch64-clang mingw-w64-clang-aarch64-gcc-compat mingw-w64-clang-aarch64-make make ``` -You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`) \ No newline at end of file +You will need to ensure your PATH includes go, cmake, gcc and clang mingw32-make to build ollama from source. (typically `C:\msys64\clangarm64\bin\`) + + +## Advanced CPU Vector Settings + +On x86, running `make` will compile several CPU runners which can run on different CPU families. At runtime, Ollama will auto-detect the best variation to load. If GPU libraries are present at build time, Ollama also compiles GPU runners with the `AVX` CPU vector feature enabled. This provides a good performance balance when loading large models that split across GPU and CPU with broad compatibility. Some users may prefer no vector extensions (e.g. older Xeon/Celeron processors, or hypervisors that mask the vector features) while other users may prefer turning on many more vector extensions to further improve performance for split model loads. + +To customize the set of CPU vector features enabled for a CPU runner and all GPU runners, use CUSTOM_CPU_FLAGS during the build. + +To build without any vector flags: + +``` +make CUSTOM_CPU_FLAGS="" +``` + +To build with both AVX and AVX2: +``` +make CUSTOM_CPU_FLAGS=avx,avx2 +``` + +To build with AVX512 features turned on: + +``` +make CUSTOM_CPU_FLAGS=avx,avx2,avx512,avx512vbmi,avx512vnni,avx512bf16 +``` + +> [!NOTE] +> If you are experimenting with different flags, make sure to do a `make clean` between each change to ensure everything is rebuilt with the new compiler flags diff --git a/docs/docker.md b/docs/docker.md index 314666b26..9dd387e3a 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -50,6 +50,9 @@ sudo systemctl restart docker docker run -d --gpus=all -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama ``` +> [!NOTE] +> If you're running on an NVIDIA JetPack system, Ollama can't automatically discover the correct JetPack version. Pass the environment variable JETSON_JETPACK=5 or JETSON_JETPACK=6 to the container to select version 5 or 6. 
+ ### AMD GPU To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following command: @@ -63,7 +66,7 @@ docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 114 Now you can run a model: ``` -docker exec -it ollama ollama run llama3.1 +docker exec -it ollama ollama run llama3.2 ``` ### Try different models diff --git a/docs/faq.md b/docs/faq.md index b2b1ca304..387d752b2 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -32,7 +32,7 @@ When using the API, specify the `num_ctx` parameter: ```shell curl http://localhost:11434/api/generate -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt": "Why is the sky blue?", "options": { "num_ctx": 4096 @@ -151,7 +151,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e Ollama runs an HTTP server and can be exposed using a proxy server such as Nginx. To do so, configure the proxy to forward requests and optionally set required headers (if not exposing Ollama on the network). For example, with Nginx: -``` +```nginx server { listen 80; server_name example.com; # Replace with your domain or IP @@ -232,7 +232,7 @@ curl http://localhost:11434/api/chat -d '{"model": "mistral"}' To preload a model using the CLI, use the command: ```shell -ollama run llama3.1 "" +ollama run llama3.2 "" ``` ## How do I keep a model loaded in memory or make it unload immediately? @@ -240,7 +240,7 @@ ollama run llama3.1 "" By default models are kept in memory for 5 minutes before being unloaded. This allows for quicker response times if you're making numerous requests to the LLM. If you want to immediately unload a model from memory, use the `ollama stop` command: ```shell -ollama stop llama3.1 +ollama stop llama3.2 ``` If you're using the API, use the `keep_alive` parameter with the `/api/generate` and `/api/chat` endpoints to set the amount of time that a model stays in memory. The `keep_alive` parameter can be set to: @@ -251,12 +251,12 @@ If you're using the API, use the `keep_alive` parameter with the `/api/generate` For example, to preload a model and leave it in memory use: ```shell -curl http://localhost:11434/api/generate -d '{"model": "llama3.1", "keep_alive": -1}' +curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": -1}' ``` To unload the model and free up memory use: ```shell -curl http://localhost:11434/api/generate -d '{"model": "llama3.1", "keep_alive": 0}' +curl http://localhost:11434/api/generate -d '{"model": "llama3.2", "keep_alive": 0}' ``` Alternatively, you can change the amount of time all models are loaded into memory by setting the `OLLAMA_KEEP_ALIVE` environment variable when starting the Ollama server. The `OLLAMA_KEEP_ALIVE` variable uses the same parameter types as the `keep_alive` parameter types mentioned above. Refer to the section explaining [how to configure the Ollama server](#how-do-i-configure-ollama-server) to correctly set the environment variable. @@ -285,4 +285,28 @@ Note: Windows with Radeon GPUs currently default to 1 model maximum due to limit ## How does Ollama load models on multiple GPUs? -Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. 
If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs. +When loading a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transferring across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs. + +## How can I enable Flash Attention? + +Flash Attention is a feature of most modern models that can significantly reduce memory usage as the context size grows. To enable Flash Attention, set the `OLLAMA_FLASH_ATTENTION` environment variable to `1` when starting the Ollama server. + +## How can I set the quantization type for the K/V cache? + +The K/V context cache can be quantized to significantly reduce memory usage when Flash Attention is enabled. + +To use quantized K/V cache with Ollama you can set the following environment variable: + +- `OLLAMA_KV_CACHE_TYPE` - The quantization type for the K/V cache. Default is `f16`. + +> Note: Currently this is a global option - meaning all models will run with the specified quantization type. + +The currently available K/V cache quantization types are: + +- `f16` - high precision and memory usage (default). +- `q8_0` - 8-bit quantization, uses approximately 1/2 the memory of `f16` with a very small loss in precision, this usually has no noticeable impact on the model's quality (recommended if not using f16). +- `q4_0` - 4-bit quantization, uses approximately 1/4 the memory of `f16` with a small-medium loss in precision that may be more noticeable at higher context sizes. + +How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count. + +You may need to experiment with different quantization types to find the best balance between memory usage and quality. diff --git a/docs/gpu.md b/docs/gpu.md index 2913a2e27..691746d0d 100644 --- a/docs/gpu.md +++ b/docs/gpu.md @@ -28,6 +28,7 @@ Check your compute compatibility to see if your card is supported: | 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` | | | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` | +For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia) ### GPU Selection @@ -74,6 +75,10 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the server. If you have an unsupported AMD GPU you can experiment using the list of supported types below. +If you have multiple GPUs with different GFX versions, append the numeric device +number to the environment variable to set them individually. For example, +`HSA_OVERRIDE_GFX_VERSION_0=10.3.0` and `HSA_OVERRIDE_GFX_VERSION_1=11.0.0` + At this time, the known supported GPU types on linux are the following LLVM Targets. This table shows some example GPUs that map to these LLVM targets: | **LLVM Target** | **An Example GPU** | @@ -99,9 +104,10 @@ Reach out on [Discord](https://discord.gg/ollama) or file an ### GPU Selection If you have multiple AMD GPUs in your system and want to limit Ollama to use a -subset, you can set `HIP_VISIBLE_DEVICES` to a comma separated list of GPUs. 
+subset, you can set `ROCR_VISIBLE_DEVICES` to a comma separated list of GPUs. You can see the list of devices with `rocminfo`. If you want to ignore the GPUs -and force CPU usage, use an invalid GPU ID (e.g., "-1") +and force CPU usage, use an invalid GPU ID (e.g., "-1"). When available, use the +`Uuid` to uniquely identify the device instead of numeric value. ### Container Permission diff --git a/docs/import.md b/docs/import.md index 2346886f1..040fa299e 100644 --- a/docs/import.md +++ b/docs/import.md @@ -32,7 +32,7 @@ ollama run my-model Ollama supports importing adapters based on several different model architectures including: - * Llama (including Llama 2, Llama 3, and Llama 3.1); + * Llama (including Llama 2, Llama 3, Llama 3.1, and Llama 3.2); * Mistral (including Mistral 1, Mistral 2, and Mixtral); and * Gemma (including Gemma 1 and Gemma 2) @@ -67,14 +67,12 @@ ollama run my-model Ollama supports importing models for several different architectures including: - * Llama (including Llama 2, Llama 3, and Llama 3.1); + * Llama (including Llama 2, Llama 3, Llama 3.1, and Llama 3.2); * Mistral (including Mistral 1, Mistral 2, and Mixtral); * Gemma (including Gemma 1 and Gemma 2); and * Phi3 -This includes importing foundation models as well as any fine tuned models which which have been _fused_ with a foundation model. - - +This includes importing foundation models as well as any fine tuned models which have been _fused_ with a foundation model. ## Importing a GGUF based model or adapter If you have a GGUF based model or adapter it is possible to import it into Ollama. You can obtain a GGUF model or adapter by: @@ -83,7 +81,7 @@ If you have a GGUF based model or adapter it is possible to import it into Ollam * converting a Safetensors adapter with the `convert_lora_to_gguf.py` from Llama.cpp; or * downloading a model or adapter from a place such as HuggingFace -To import a GGUF model, create a `Modelfile` containg: +To import a GGUF model, create a `Modelfile` containing: ```dockerfile FROM /path/to/file.gguf diff --git a/docs/linux.md b/docs/linux.md index 0eec014f4..13655f423 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -10,6 +10,9 @@ curl -fsSL https://ollama.com/install.sh | sh ## Manual install +> [!NOTE] +> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. + Download and extract the package: ```shell @@ -112,6 +115,21 @@ sudo systemctl status ollama > https://www.amd.com/en/support/linux-drivers for best support of your Radeon > GPU. +## Customizing + +To customize the installation of Ollama, you can edit the systemd service file or the environment variables by running: + +``` +sudo systemctl edit ollama +``` + +Alternatively, create an override file manually in `/etc/systemd/system/ollama.service.d/override.conf`: + +```ini +[Service] +Environment="OLLAMA_DEBUG=1" +``` + ## Updating Update Ollama by running the install script again: @@ -129,7 +147,7 @@ sudo tar -C /usr -xzf ollama-linux-amd64.tgz ## Installing specific versions -Use `OLLAMA_VERSION` environment variable with the install script to install a specific version of Ollama, including pre-releases. You can find the version numbers in the [releases page](https://github.com/ollama/ollama/releases). +Use `OLLAMA_VERSION` environment variable with the install script to install a specific version of Ollama, including pre-releases. You can find the version numbers in the [releases page](https://github.com/ollama/ollama/releases). 
For example: diff --git a/docs/modelfile.md b/docs/modelfile.md index a33f180b7..b1c4e8a3c 100644 --- a/docs/modelfile.md +++ b/docs/modelfile.md @@ -50,7 +50,7 @@ INSTRUCTION arguments An example of a `Modelfile` creating a mario blueprint: ```modelfile -FROM llama3.1 +FROM llama3.2 # sets the temperature to 1 [higher is more creative, lower is more coherent] PARAMETER temperature 1 # sets the context window size to 4096, this controls how many tokens the LLM can use as context to generate the next token @@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant. To use this: 1. Save it as a file (e.g. `Modelfile`) -2. `ollama create choose-a-model-name -f '` +2. `ollama create choose-a-model-name -f ` 3. `ollama run choose-a-model-name` 4. Start using the model! @@ -72,10 +72,10 @@ More examples are available in the [examples directory](../examples). To view the Modelfile of a given model, use the `ollama show --modelfile` command. ```bash - > ollama show --modelfile llama3.1 + > ollama show --modelfile llama3.2 # Modelfile generated by "ollama show" # To build a new Modelfile based on this one, replace the FROM line with: - # FROM llama3.1:latest + # FROM llama3.2:latest FROM /Users/pdevine/.ollama/models/blobs/sha256-00e1317cbf74d901080d7100f57580ba8dd8de57203072dc6f668324ba545f29 TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|> @@ -103,7 +103,7 @@ FROM : #### Build from existing model ```modelfile -FROM llama3.1 +FROM llama3.2 ``` A list of available base models: @@ -120,7 +120,7 @@ FROM The model directory should contain the Safetensors weights for a supported architecture. Currently supported model architectures: - * Llama (including Llama 2, Llama 3, and Llama 3.1) + * Llama (including Llama 2, Llama 3, Llama 3.1, and Llama 3.2) * Mistral (including Mistral 1, Mistral 2, and Mixtral) * Gemma (including Gemma 1 and Gemma 2) * Phi3 @@ -156,7 +156,7 @@ PARAMETER | seed | Sets the random number seed to use for generation. Setting this to a specific number will make the model generate the same text for the same prompt. (Default: 0) | int | seed 42 | | stop | Sets the stop sequences to use. When this pattern is encountered the LLM will stop generating text and return. Multiple stop patterns may be set by specifying multiple separate `stop` parameters in a modelfile. | string | stop "AI assistant:" | | tfs_z | Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting. (default: 1) | float | tfs_z 1 | -| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 | +| num_predict | Maximum number of tokens to predict when generating text. (Default: -1, infinite generation) | int | num_predict 42 | | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 | | top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 | | min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. 
The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 | diff --git a/docs/openai.md b/docs/openai.md index c6df0fecb..b0f9b353c 100644 --- a/docs/openai.md +++ b/docs/openai.md @@ -25,7 +25,7 @@ chat_completion = client.chat.completions.create( 'content': 'Say this is a test', } ], - model='llama3.1', + model='llama3.2', ) response = client.chat.completions.create( @@ -37,7 +37,7 @@ response = client.chat.completions.create( {"type": "text", "text": "What's in this image?"}, { "type": "image_url", - "image_url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45
ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC", + "image_url": 
"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZO
ndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC", }, ], } @@ -46,19 +46,53 @@ response = client.chat.completions.create( ) completion = client.completions.create( - model="llama3.1", + model="llama3.2", prompt="Say this is a test", ) list_completion = client.models.list() -model = client.models.retrieve("llama3.1") +model = client.models.retrieve("llama3.2") embeddings = client.embeddings.create( model="all-minilm", input=["why is the sky blue?", "why is the grass green?"], ) ``` +#### Structured outputs +```py +from pydantic import BaseModel +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama") + +# Define the schema for the response +class FriendInfo(BaseModel): + name: str + age: int + is_available: bool + +class FriendList(BaseModel): + friends: list[FriendInfo] + +try: + completion = client.beta.chat.completions.parse( + temperature=0, + model="llama3.1:8b", + messages=[ + {"role": "user", "content": "I have two friends. The first is Ollama 22 years old busy saving the world, and the second is Alonso 23 years old and wants to hang out. Return a list of friends in JSON format"} + ], + response_format=FriendList, + ) + + friends_response = completion.choices[0].message + if friends_response.parsed: + print(friends_response.parsed) + elif friends_response.refusal: + print(friends_response.refusal) +except Exception as e: + print(f"Error: {e}") +``` ### OpenAI JavaScript library @@ -74,7 +108,7 @@ const openai = new OpenAI({ const chatCompletion = await openai.chat.completions.create({ messages: [{ role: 'user', content: 'Say this is a test' }], - model: 'llama3.1', + model: 'llama3.2', }) const response = await openai.chat.completions.create({ @@ -86,7 +120,7 @@ const response = await openai.chat.completions.create({ { type: "text", text: "What's in this image?" 
}, { type: "image_url", - image_url: "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJW
lZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC", + image_url: "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oF
DJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC", }, ], }, @@ -94,13 +128,13 @@ const response = await openai.chat.completions.create({ }) const completion = await openai.completions.create({ - model: "llama3.1", + model: "llama3.2", prompt: "Say this is a test.", }) const listCompletion = await openai.models.list() -const model = await openai.models.retrieve("llama3.1") +const model = await openai.models.retrieve("llama3.2") const embedding = await openai.embeddings.create({ model: "all-minilm", @@ -114,7 +148,7 @@ const embedding = await openai.embeddings.create({ curl http://localhost:11434/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "llama3.1", + "model": "llama3.2", "messages": [ { "role": "system", @@ -142,7 +176,7 @@ curl http://localhost:11434/v1/chat/completions \ { "type": "image_url", "image_url": { - "url": 
"iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdu
rUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC" + "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxm
bb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC" } } ] @@ -154,13 +188,13 @@ curl http://localhost:11434/v1/chat/completions \ curl http://localhost:11434/v1/completions \ -H "Content-Type: application/json" \ -d '{ - "model": "llama3.1", + "model": "llama3.2", "prompt": "Say this is a test" }' curl http://localhost:11434/v1/models -curl http://localhost:11434/v1/models/llama3.1 +curl http://localhost:11434/v1/models/llama3.2 curl http://localhost:11434/v1/embeddings \ -H "Content-Type: application/json" \ @@ -181,7 +215,7 @@ curl http://localhost:11434/v1/embeddings \ - [x] JSON mode - [x] Reproducible outputs - [x] Vision -- [x] Tools (streaming support coming soon) +- [x] Tools - [ ] Logprobs #### Supported request fields @@ -199,6 +233,8 @@ curl http://localhost:11434/v1/embeddings \ - [x] `seed` - [x] `stop` - [x] `stream` +- [x] `stream_options` + - [x] `include_usage` - [x] `temperature` - [x] `top_p` - [x] `max_tokens` @@ 
-227,6 +263,8 @@ curl http://localhost:11434/v1/embeddings \ - [x] `seed` - [x] `stop` - [x] `stream` +- [x] `stream_options` + - [x] `include_usage` - [x] `temperature` - [x] `top_p` - [x] `max_tokens` @@ -274,7 +312,7 @@ curl http://localhost:11434/v1/embeddings \ Before using a model, pull it locally `ollama pull`: ```shell -ollama pull llama3.1 +ollama pull llama3.2 ``` ### Default model names @@ -282,7 +320,7 @@ ollama pull llama3.1 For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name: ``` -ollama cp llama3.1 gpt-3.5-turbo +ollama cp llama3.2 gpt-3.5-turbo ``` Afterwards, this new model name can be specified the `model` field: diff --git a/docs/template.md b/docs/template.md index 192d878d4..47e1399a2 100644 --- a/docs/template.md +++ b/docs/template.md @@ -33,7 +33,7 @@ Omitting a template in these models puts the responsibility of correctly templat To add templates in your model, you'll need to add a `TEMPLATE` command to the Modelfile. Here's an example using Meta's Llama 3. ```dockerfile -FROM llama3.1 +FROM llama3.2 TEMPLATE """{{- if .System }}<|start_header_id|>system<|end_header_id|> @@ -111,7 +111,7 @@ Keep the following tips and best practices in mind when working with Go template ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2. -```gotmpl +```go {{- range .Messages }}<|im_start|>{{ .Role }} {{ .Content }}<|im_end|> {{ end }}<|im_start|>assistant @@ -125,7 +125,7 @@ Tools support can be added to a model by adding a `{{ .Tools }}` node to the tem Mistral v0.3 and Mixtral 8x22B supports tool calling. -```gotmpl +```go {{- range $index, $_ := .Messages }} {{- if eq .Role "user" }} {{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS] @@ -151,7 +151,7 @@ Fill-in-middle support can be added to a model by adding a `{{ .Suffix }}` node CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://ollama.com/library/codellama:13b-code) code completion models support fill-in-middle. -```gotmpl +```go
 {{ .Prompt }} {{ .Suffix }} 
 ```
 
diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md
index 0a89b87f9..28f4350aa 100644
--- a/docs/troubleshooting.md
+++ b/docs/troubleshooting.md
@@ -80,7 +80,7 @@ If you are using a container to run Ollama, make sure you've set up the containe
 
 Sometimes Ollama can have difficulty initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem:
 
-- If you are using a container, is the container runtime working?  Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama wont be able to see your NVIDIA GPU.
+- If you are using a container, is the container runtime working?  Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama won't be able to see your NVIDIA GPU.
 - Is the uvm driver loaded? `sudo nvidia-modprobe -u`
 - Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
 - Try rebooting
@@ -95,13 +95,21 @@ If none of those resolve the problem, gather additional information and file an
 
 On Linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device.  If permissions are not set up correctly, Ollama will detect this and report an error in the server log.
 
-When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU.  Use `ls -ld /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the group assignments on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices.
+When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU.  Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices.  For example, in the following output `crw-rw---- 1 0  44 226,   0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`.
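+
+As an illustrative sketch (not the only valid invocation), the group ID from the example output above could be passed to the container like this; repeat `--group-add` for each distinct numeric group ID that `ls -lnd` reports on your host, and adjust the image tag and other flags to match your setup:
+
+```shell
+docker run -d --device /dev/kfd --device /dev/dri \
+  --group-add 44 \
+  -v ollama:/root/.ollama -p 11434:11434 ollama/ollama:rocm
+```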
+
+If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker.  Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
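+
+For reference, a minimal `/etc/docker/daemon.json` containing only this option could look like the following sketch; if the file already has other settings, merge the key into your existing configuration instead of replacing it, and restart Docker afterwards:
+
+```json
+{
+  "exec-opts": ["native.cgroupdriver=cgroupfs"]
+}
+```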
 
 If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
 - `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries.  This can help show more detailed error codes that can help troubleshoot problems
 - `OLLAMA_DEBUG=1` During GPU discovery additional information will be reported
 - Check dmesg for any errors from amdgpu or kfd drivers `sudo dmesg | grep -i amdgpu` and `sudo dmesg | grep -i kfd`
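+
+For example, the two variables above can be combined when launching the server manually to capture a more verbose log (a sketch; adapt it to however you normally start Ollama):
+
+```shell
+AMD_LOG_LEVEL=3 OLLAMA_DEBUG=1 ollama serve
+```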
 
+## Multiple AMD GPUs
+
+If you experience gibberish responses when models load across multiple AMD GPUs on Linux, see the following guide.
+
+- https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/mgpu.html#mgpu-known-issues-and-limitations
+
 ## Windows Terminal Errors
 
 Older versions of Windows 10 (e.g., 21H1) are known to have a bug where the standard terminal program does not display control characters correctly.  This can result in long strings of control sequences like `←[?25h←[?25l` being displayed, sometimes erroring with `The parameter is incorrect`.  To resolve this problem, please update to Win 10 22H1 or newer.
diff --git a/docs/tutorials.md b/docs/tutorials.md
deleted file mode 100644
index 0f520c952..000000000
--- a/docs/tutorials.md
+++ /dev/null
@@ -1,9 +0,0 @@
-# Tutorials
-
-Here is a list of ways you can use Ollama with other tools to build interesting applications.
-
-- [Using LangChain with Ollama in JavaScript](./tutorials/langchainjs.md)
-- [Using LangChain with Ollama in Python](./tutorials/langchainpy.md)
-- [Running Ollama on NVIDIA Jetson Devices](./tutorials/nvidia-jetson.md)
-
-Also be sure to check out the [examples](../examples) directory for more ways to use Ollama.
diff --git a/docs/tutorials/fly-gpu.md b/docs/tutorials/fly-gpu.md
deleted file mode 100644
index 24802ddb1..000000000
--- a/docs/tutorials/fly-gpu.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# Running Ollama on Fly.io GPU Instances
-
-Ollama runs with little to no configuration on [Fly.io GPU instances](https://fly.io/docs/gpus/gpu-quickstart/). If you don't have access to GPUs yet, you'll need to [apply for access](https://fly.io/gpu/) on the waitlist. Once you're accepted, you'll get an email with instructions on how to get started.
-
-Create a new app with `fly apps create`:
-
-```bash
-fly apps create
-```
-
-Then create a `fly.toml` file in a new folder that looks like this:
-
-```toml
-app = "sparkling-violet-709"
-primary_region = "ord"
-vm.size = "a100-40gb" # see https://fly.io/docs/gpus/gpu-quickstart/ for more info
-
-[build]
-  image = "ollama/ollama"
-
-[http_service]
-  internal_port = 11434
-  force_https = false
-  auto_stop_machines = true
-  auto_start_machines = true
-  min_machines_running = 0
-  processes = ["app"]
-
-[mounts]
-  source = "models"
-  destination = "/root/.ollama"
-  initial_size = "100gb"
-```
-
-Then create a [new private IPv6 address](https://fly.io/docs/reference/private-networking/#flycast-private-load-balancing) for your app:
-
-```bash
-fly ips allocate-v6 --private
-```
-
-Then deploy your app:
-
-```bash
-fly deploy
-```
-
-And finally you can access it interactively with a new Fly.io Machine:
-
-```
-fly machine run -e OLLAMA_HOST=http://your-app-name.flycast --shell ollama/ollama
-```
-
-```bash
-$ ollama run openchat:7b-v3.5-fp16
->>> How do I bake chocolate chip cookies?
- To bake chocolate chip cookies, follow these steps:
-
-1. Preheat the oven to 375°F (190°C) and line a baking sheet with parchment paper or silicone baking mat.
-
-2. In a large bowl, mix together 1 cup of unsalted butter (softened), 3/4 cup granulated sugar, and 3/4
-cup packed brown sugar until light and fluffy.
-
-3. Add 2 large eggs, one at a time, to the butter mixture, beating well after each addition. Stir in 1
-teaspoon of pure vanilla extract.
-
-4. In a separate bowl, whisk together 2 cups all-purpose flour, 1/2 teaspoon baking soda, and 1/2 teaspoon
-salt. Gradually add the dry ingredients to the wet ingredients, stirring until just combined.
-
-5. Fold in 2 cups of chocolate chips (or chunks) into the dough.
-
-6. Drop rounded tablespoons of dough onto the prepared baking sheet, spacing them about 2 inches apart.
-
-7. Bake for 10-12 minutes, or until the edges are golden brown. The centers should still be slightly soft.
-
-8. Allow the cookies to cool on the baking sheet for a few minutes before transferring them to a wire rack
-to cool completely.
-
-Enjoy your homemade chocolate chip cookies!
-```
-
-When you set it up like this, it will automatically turn off when you're done using it. Then when you access it again, it will automatically turn back on. This is a great way to save money on GPU instances when you're not using them. If you want a persistent wake-on-use connection to your Ollama instance, you can set up a [connection to your Fly network using WireGuard](https://fly.io/docs/reference/private-networking/#discovering-apps-through-dns-on-a-wireguard-connection). Then you can access your Ollama instance at `http://your-app-name.flycast`.
-
-And that's it!
diff --git a/docs/tutorials/langchainjs.md b/docs/tutorials/langchainjs.md
deleted file mode 100644
index f925869b5..000000000
--- a/docs/tutorials/langchainjs.md
+++ /dev/null
@@ -1,77 +0,0 @@
-# Using LangChain with Ollama using JavaScript
-
-In this tutorial, we are going to use JavaScript with LangChain and Ollama to learn about something just a touch more recent. In August 2023, there was a series of wildfires on Maui. There is no way an LLM trained before that time can know about this, since their training data would not include anything as recent as that. So we can find the [Wikipedia article about the fires](https://en.wikipedia.org/wiki/2023_Hawaii_wildfires) and ask questions about the contents.
-
-To get started, let's just use **LangChain** to ask a simple question to a model. To do this with JavaScript, we need to install **LangChain**:
-
-```bash
-npm install @langchain/community
-```
-
-Now we can start building out our JavaScript:
-
-```javascript
-import { Ollama } from "@langchain/community/llms/ollama";
-
-const ollama = new Ollama({
-  baseUrl: "http://localhost:11434",
-  model: "llama3.1",
-});
-
-const answer = await ollama.invoke(`why is the sky blue?`);
-
-console.log(answer);
-```
-
-That will get us the same thing as if we ran `ollama run llama3.1 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
-
-```bash
-npm install cheerio
-```
-
-```javascript
-import { CheerioWebBaseLoader } from "langchain/document_loaders/web/cheerio";
-
-const loader = new CheerioWebBaseLoader("https://en.wikipedia.org/wiki/2023_Hawaii_wildfires");
-const data = await loader.load();
-```
-
-That will load the document. Although this page is smaller than the Odyssey, it is certainly bigger than the context size for most LLMs. So we are going to need to split into smaller pieces, and then select just the pieces relevant to our question. This is a great use for a vector datastore. In this example, we will use the **MemoryVectorStore** that is part of **LangChain**. But there is one more thing we need to get the content into the datastore. We have to run an embeddings process that converts the tokens in the text into a series of vectors. And for that, we are going to use **Tensorflow**. There is a lot of stuff going on in this one. First, install the **Tensorflow** components that we need.
-
-```javascript
-npm install @tensorflow/tfjs-core@3.6.0 @tensorflow/tfjs-converter@3.6.0 @tensorflow-models/universal-sentence-encoder@1.3.3 @tensorflow/tfjs-node@4.10.0
-```
-
-If you just install those components without the version numbers, it will install the latest versions, but there are conflicts within **Tensorflow**, so you need to install the compatible versions.
-
-```javascript
-import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"
-import { MemoryVectorStore } from "langchain/vectorstores/memory";
-import "@tensorflow/tfjs-node";
-import { TensorFlowEmbeddings } from "langchain/embeddings/tensorflow";
-
-// Split the text into 500 character chunks. And overlap each chunk by 20 characters
-const textSplitter = new RecursiveCharacterTextSplitter({
- chunkSize: 500,
- chunkOverlap: 20
-});
-const splitDocs = await textSplitter.splitDocuments(data);
-
-// Then use the TensorFlow Embedding to store these chunks in the datastore
-const vectorStore = await MemoryVectorStore.fromDocuments(splitDocs, new TensorFlowEmbeddings());
-```
-
-To connect the datastore to a question asked to a LLM, we need to use the concept at the heart of **LangChain**: the chain. Chains are a way to connect a number of activities together to accomplish a particular tasks. There are a number of chain types available, but for this tutorial we are using the **RetrievalQAChain**.
-
-```javascript
-import { RetrievalQAChain } from "langchain/chains";
-
-const retriever = vectorStore.asRetriever();
-const chain = RetrievalQAChain.fromLLM(ollama, retriever);
-const result = await chain.call({query: "When was Hawaii's request for a major disaster declaration approved?"});
-console.log(result.text)
-```
-
-So we created a retriever, which is a way to return the chunks that match a query from a datastore. And then connect the retriever and the model via a chain. Finally, we send a query to the chain, which results in an answer using our document as a source. The answer it returned was correct, August 10, 2023.
-
-And that is a simple introduction to what you can do with **LangChain** and **Ollama.**
diff --git a/docs/tutorials/langchainpy.md b/docs/tutorials/langchainpy.md
deleted file mode 100644
index 06543a070..000000000
--- a/docs/tutorials/langchainpy.md
+++ /dev/null
@@ -1,85 +0,0 @@
-# Using LangChain with Ollama in Python
-
-Let's imagine we are studying the classics, such as **the Odyssey** by **Homer**. We might have a question about Neleus and his family. If you ask llama2 for that info, you may get something like:
-
-> I apologize, but I'm a large language model, I cannot provide information on individuals or families that do not exist in reality. Neleus is not a real person or character, and therefore does not have a family or any other personal details. My apologies for any confusion. Is there anything else I can help you with?
-
-This sounds like a typical censored response, but even llama2-uncensored gives a mediocre answer:
-
-> Neleus was a legendary king of Pylos and the father of Nestor, one of the Argonauts. His mother was Clymene, a sea nymph, while his father was Neptune, the god of the sea.
-
-So let's figure out how we can use **LangChain** with Ollama to ask our question to the actual document, the Odyssey by Homer, using Python.
-
-Let's start by asking a simple question that we can get an answer to from the **Llama2** model using **Ollama**. First, we need to install the **LangChain** package:
-
-`pip install langchain_community`
-
-Then we can create a model and ask the question:
-
-```python
-from langchain_community.llms import Ollama
-ollama = Ollama(
-    base_url='http://localhost:11434',
-    model="llama3"
-)
-print(ollama.invoke("why is the sky blue"))
-```
-
-Notice that we are defining the model and the base URL for Ollama.
-
-Now let's load a document to ask questions against. I'll load up the Odyssey by Homer, which you can find at Project Gutenberg. We will need **WebBaseLoader** which is part of **LangChain** and loads text from any webpage. On my machine, I also needed to install **bs4** to get that to work, so run `pip install bs4`.
-
-```python
-from langchain.document_loaders import WebBaseLoader
-loader = WebBaseLoader("https://www.gutenberg.org/files/1727/1727-h/1727-h.htm")
-data = loader.load()
-```
-
-This file is pretty big. Just the preface is 3000 tokens. Which means the full document won't fit into the context for the model. So we need to split it up into smaller pieces.
-
-```python
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-
-text_splitter=RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
-all_splits = text_splitter.split_documents(data)
-```
-
-It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
-We also need to pull embedding model: `ollama pull nomic-embed-text`
-```python
-from langchain.embeddings import OllamaEmbeddings
-from langchain.vectorstores import Chroma
-oembed = OllamaEmbeddings(base_url="http://localhost:11434", model="nomic-embed-text")
-vectorstore = Chroma.from_documents(documents=all_splits, embedding=oembed)
-```
-
-Now let's ask a question from the document. **Who was Neleus, and who is in his family?** Neleus is a character in the Odyssey, and the answer can be found in our text.
-
-```python
-question="Who is Neleus and who is in Neleus' family?"
-docs = vectorstore.similarity_search(question)
-len(docs)
-```
-
-This will output the number of matches for chunks of data similar to the search.
-
-The next thing is to send the question and the relevant parts of the docs to the model to see if we can get a good answer. But we are stitching two parts of the process together, and that is called a chain. This means we need to define a chain:
-
-```python
-from langchain.chains import RetrievalQA
-qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever())
-res = qachain.invoke({"query": question})
-print(res['result'])
-```
-
-The answer received from this chain was:
-
-> Neleus is a character in Homer's "Odyssey" and is mentioned in the context of Penelope's suitors. Neleus is the father of Chloris, who is married to Neleus and bears him several children, including Nestor, Chromius, Periclymenus, and Pero. Amphinomus, the son of Nisus, is also mentioned as a suitor of Penelope and is known for his good natural disposition and agreeable conversation.
-
-It's not a perfect answer, as it implies Neleus married his daughter when actually Chloris "was the youngest daughter to Amphion son of Iasus and king of Minyan Orchomenus, and was Queen in Pylos".
-
-I updated the chunk_overlap for the text splitter to 20 and tried again and got a much better answer:
-
-> Neleus is a character in Homer's epic poem "The Odyssey." He is the husband of Chloris, who is the youngest daughter of Amphion son of Iasus and king of Minyan Orchomenus. Neleus has several children with Chloris, including Nestor, Chromius, Periclymenus, and Pero.
-
-And that is a much better answer.
diff --git a/docs/tutorials/nvidia-jetson.md b/docs/tutorials/nvidia-jetson.md
deleted file mode 100644
index bb77c486f..000000000
--- a/docs/tutorials/nvidia-jetson.md
+++ /dev/null
@@ -1,15 +0,0 @@
-# Running Ollama on NVIDIA Jetson Devices
-
-Ollama runs well on [NVIDIA Jetson Devices](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/) and should run out of the box with the standard installation instructions. 
-
-The following has been tested on [JetPack 5.1.2](https://developer.nvidia.com/embedded/jetpack), but should also work on JetPack 6.0.
-
-- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
-- Pull the model you want to use (e.g. mistral): `ollama pull mistral`
-- Start an interactive session: `ollama run mistral`
-
-And that's it!
-
-# Running Ollama in Docker
-
-When running GPU accelerated applications in Docker, it is highly recommended to use [dusty-nv jetson-containers repo](https://github.com/dusty-nv/jetson-containers).
\ No newline at end of file
diff --git a/docs/windows.md b/docs/windows.md
index 372a35aa8..80bebed47 100644
--- a/docs/windows.md
+++ b/docs/windows.md
@@ -1,22 +1,15 @@
-# Ollama Windows Preview
+# Ollama Windows
 
-Welcome to the Ollama Windows preview.
+Welcome to Ollama for Windows.
 
 No more WSL required!
 
 Ollama now runs as a native Windows application, including NVIDIA and AMD Radeon GPU support.
-After installing Ollama Windows Preview, Ollama will run in the background and
+After installing Ollama for Windows, Ollama will run in the background and
 the `ollama` command line is available in `cmd`, `powershell` or your favorite
 terminal application. As usual the Ollama [api](./api.md) will be served on
 `http://localhost:11434`.
 
-As this is a preview release, you should expect a few bugs here and there.  If
-you run into a problem you can reach out on
-[Discord](https://discord.gg/ollama), or file an
-[issue](https://github.com/ollama/ollama/issues).
-Logs will often be helpful in diagnosing the problem (see
-[Troubleshooting](#troubleshooting) below)
-
 ## System Requirements
 
 * Windows 10 22H2 or newer, Home or Pro
@@ -25,19 +18,41 @@ Logs will often be helpful in diagnosing the problem (see
 
 Ollama uses unicode characters for progress indication, which may render as unknown squares in some older terminal fonts in Windows 10. If you see this, try changing your terminal font settings.
 
+## Filesystem Requirements
+
+The Ollama install does not require Administrator and installs in your home directory by default.  You'll need at least 4GB of space for the binary install.  Once you've installed Ollama, you'll need additional space for storing the large language models, which can be tens to hundreds of GB in size.  If your home directory doesn't have enough space, you can change where the binaries are installed and where the models are stored.
+
+### Changing Install Location
+
+To install the Ollama application in a location other than your home directory, start the installer with the following flag:
+
+```powershell
+OllamaSetup.exe /DIR="d:\some\location"
+```
+
+### Changing Model Location
+
+To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
+
+1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
+
+2. Click on _Edit environment variables for your account_.
+
+3. Edit or create a new variable for your user account named `OLLAMA_MODELS`, pointing to the directory where you want the models stored.
+
+4. Click OK/Apply to save.
+
+If Ollama is already running, quit the tray application and relaunch it from the Start menu, or from a new terminal started after you saved the environment variables.
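+
+As an alternative to the GUI steps above, the same per-user variable can be set from PowerShell; the path `D:\models` below is only an example, so substitute the directory you actually want to use:
+
+```powershell
+[Environment]::SetEnvironmentVariable("OLLAMA_MODELS", "D:\models", "User")
+```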
+
 ## API Access
 
 Here's a quick example showing API access from `powershell`
 ```powershell
-(Invoke-WebRequest -method POST -Body '{"model":"llama3.1", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
+(Invoke-WebRequest -method POST -Body '{"model":"llama3.2", "prompt":"Why is the sky blue?", "stream": false}' -uri http://localhost:11434/api/generate ).Content | ConvertFrom-json
 ```
 
 ## Troubleshooting
 
-While we're in preview, `OLLAMA_DEBUG` is always enabled, which adds
-a "view logs" menu item to the app, and increases logging for the GUI app and
-server.
-
 Ollama on Windows stores files in a few different locations.  You can view them in
 the explorer window by pressing the Windows key + `R` and typing in:
 - `explorer %LOCALAPPDATA%\Ollama` contains logs, and downloaded updates
@@ -52,6 +67,10 @@ the explorer window by hitting `+R` and type in:
 
 The Ollama Windows installer registers an Uninstaller application.  Under `Add or remove programs` in Windows Settings, you can uninstall Ollama.
 
+> [!NOTE]
+> If you have [changed the OLLAMA_MODELS location](#changing-model-location), the installer will not remove your downloaded models.
+
+
 ## Standalone CLI
 
 The easiest way to install Ollama on Windows is to use the `OllamaSetup.exe`
@@ -64,3 +83,6 @@ If you'd like to install or integrate Ollama as a service, a standalone
 and GPU library dependencies for Nvidia and AMD. This allows for embedding
 Ollama in existing applications, or running it as a system service via `ollama
 serve` with tools such as [NSSM](https://nssm.cc/).
+
+> [!NOTE]  
+> If you are upgrading from a prior version, you should remove the old directories first.
diff --git a/envconfig/config.go b/envconfig/config.go
index 239c49fe6..831a3a730 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -72,6 +72,7 @@ func Origins() (origins []string) {
 		"app://*",
 		"file://*",
 		"tauri://*",
+		"vscode-webview://*",
 	)
 
 	return origins
@@ -152,6 +153,8 @@ var (
 	Debug = Bool("OLLAMA_DEBUG")
 	// FlashAttention enables the experimental flash attention feature.
 	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
+	// KvCacheType is the quantization type for the K/V cache.
+	KvCacheType = String("OLLAMA_KV_CACHE_TYPE")
 	// NoHistory disables readline history.
 	NoHistory = Bool("OLLAMA_NOHISTORY")
 	// NoPrune disables pruning of model blobs on startup.
@@ -160,6 +163,8 @@ var (
 	SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
 	// IntelGPU enables experimental Intel GPU detection.
 	IntelGPU = Bool("OLLAMA_INTEL_GPU")
+	// MultiUserCache optimizes prompt caching for multi-user scenarios
+	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
 )
 
 func String(s string) func() string {
@@ -170,7 +175,6 @@ func String(s string) func() string {
 
 var (
 	LLMLibrary = String("OLLAMA_LLM_LIBRARY")
-	TmpDir     = String("OLLAMA_TMPDIR")
 
 	CudaVisibleDevices    = String("CUDA_VISIBLE_DEVICES")
 	HipVisibleDevices     = String("HIP_VISIBLE_DEVICES")
@@ -232,6 +236,7 @@ func AsMap() map[string]EnvVar {
 	ret := map[string]EnvVar{
 		"OLLAMA_DEBUG":             {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION":   {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
+		"OLLAMA_KV_CACHE_TYPE":     {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
 		"OLLAMA_GPU_OVERHEAD":      {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
 		"OLLAMA_HOST":              {"OLLAMA_HOST", Host(), "IP Address for the ollama server (default 127.0.0.1:11434)"},
 		"OLLAMA_KEEP_ALIVE":        {"OLLAMA_KEEP_ALIVE", KeepAlive(), "The duration that models stay loaded in memory (default \"5m\")"},
@@ -245,7 +250,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NUM_PARALLEL":      {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
 		"OLLAMA_ORIGINS":           {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":      {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
-		"OLLAMA_TMPDIR":            {"OLLAMA_TMPDIR", TmpDir(), "Location for temporary files"},
+		"OLLAMA_MULTIUSER_CACHE":   {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
 
 		// Informational
 		"HTTP_PROXY":  {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
@@ -262,10 +267,10 @@ func AsMap() map[string]EnvVar {
 
 	if runtime.GOOS != "darwin" {
 		ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"}
-		ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible"}
-		ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible"}
-		ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which VK AMD devices are visible"}
-		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible"}
+		ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"}
+		ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"}
+		ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which VK AMD devices are visible by numeric ID"}
+		ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
 		ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
 		ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
 	}
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index 7ac7c53e3..735b45405 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -68,6 +68,7 @@ func TestOrigins(t *testing.T) {
 			"app://*",
 			"file://*",
 			"tauri://*",
+			"vscode-webview://*",
 		}},
 		{"http://10.0.0.1", []string{
 			"http://10.0.0.1",
@@ -86,6 +87,7 @@ func TestOrigins(t *testing.T) {
 			"app://*",
 			"file://*",
 			"tauri://*",
+			"vscode-webview://*",
 		}},
 		{"http://172.16.0.1,https://192.168.0.1", []string{
 			"http://172.16.0.1",
@@ -105,6 +107,7 @@ func TestOrigins(t *testing.T) {
 			"app://*",
 			"file://*",
 			"tauri://*",
+			"vscode-webview://*",
 		}},
 		{"http://totally.safe,http://definitely.legit", []string{
 			"http://totally.safe",
@@ -124,6 +127,7 @@ func TestOrigins(t *testing.T) {
 			"app://*",
 			"file://*",
 			"tauri://*",
+			"vscode-webview://*",
 		}},
 	}
 	for _, tt := range cases {
diff --git a/examples/README.md b/examples/README.md
index b10a34914..7f349f727 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,3 +1,14 @@
 # Examples
 
 This directory contains different examples of using Ollama.
+
+## Python examples
+Ollama Python examples can be found at [ollama-python/examples](https://github.com/ollama/ollama-python/tree/main/examples)
+
+
+## JavaScript examples
+Ollama JavaScript examples can be found at [ollama-js/examples](https://github.com/ollama/ollama-js/tree/main/examples)
+
+
+## OpenAI compatibility examples
+Ollama OpenAI compatibility examples can be found at [ollama/examples/openai](../docs/openai.md)
diff --git a/examples/go-chat/main.go b/examples/go-chat/main.go
index 7663fb8f4..074303057 100644
--- a/examples/go-chat/main.go
+++ b/examples/go-chat/main.go
@@ -35,7 +35,7 @@ func main() {
 
 	ctx := context.Background()
 	req := &api.ChatRequest{
-		Model:    "llama3.1",
+		Model:    "llama3.2",
 		Messages: messages,
 	}
 
diff --git a/examples/langchain-python-rag-document/README.md b/examples/langchain-python-rag-document/README.md
index e2f3bc028..d37afc9d0 100644
--- a/examples/langchain-python-rag-document/README.md
+++ b/examples/langchain-python-rag-document/README.md
@@ -4,10 +4,10 @@ This example provides an interface for asking questions to a PDF document.
 
 ## Setup
 
-1. Ensure you have the `llama3.1` model installed:
+1. Ensure you have the `llama3.2` model installed:
 
 ```
-ollama pull llama3.1
+ollama pull llama3.2
 ```
 
 2. Install the Python Requirements.
diff --git a/examples/langchain-python-rag-document/main.py b/examples/langchain-python-rag-document/main.py
index 6f7cec9be..b93828f82 100644
--- a/examples/langchain-python-rag-document/main.py
+++ b/examples/langchain-python-rag-document/main.py
@@ -1,8 +1,8 @@
-from langchain.document_loaders import OnlinePDFLoader
-from langchain.vectorstores import Chroma
-from langchain.embeddings import GPT4AllEmbeddings
-from langchain import PromptTemplate
-from langchain.llms import Ollama
+from langchain_community.document_loaders import OnlinePDFLoader
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import GPT4AllEmbeddings
+from langchain_core.prompts import PromptTemplate
+from langchain_community.llms import Ollama
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.chains import RetrievalQA
@@ -51,7 +51,7 @@ while True:
         template=template,
     )
 
-    llm = Ollama(model="llama3.1", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
+    llm = Ollama(model="llama3.2", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
     qa_chain = RetrievalQA.from_chain_type(
         llm,
         retriever=vectorstore.as_retriever(),
diff --git a/examples/langchain-python-rag-websummary/README.md b/examples/langchain-python-rag-websummary/README.md
index 29c706a39..746c47ab4 100644
--- a/examples/langchain-python-rag-websummary/README.md
+++ b/examples/langchain-python-rag-websummary/README.md
@@ -4,10 +4,10 @@ This example summarizes the website, [https://ollama.com/blog/run-llama2-uncenso
 
 ## Running the Example
 
-1. Ensure you have the `llama3.1` model installed:
+1. Ensure you have the `llama3.2` model installed:
 
    ```bash
-   ollama pull llama3.1
+   ollama pull llama3.2
    ```
 
 2. Install the Python Requirements.
diff --git a/examples/langchain-python-rag-websummary/main.py b/examples/langchain-python-rag-websummary/main.py
index 77b09fbbc..56f8bd24e 100644
--- a/examples/langchain-python-rag-websummary/main.py
+++ b/examples/langchain-python-rag-websummary/main.py
@@ -5,7 +5,7 @@ from langchain.chains.summarize import load_summarize_chain
 loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
 docs = loader.load()
 
-llm = Ollama(model="llama3.1")
+llm = Ollama(model="llama3.2")
 chain = load_summarize_chain(llm, chain_type="stuff")
 
 result = chain.invoke(docs)
diff --git a/examples/langchain-python-simple/README.md b/examples/langchain-python-simple/README.md
index 60db2c8c3..680ab560f 100644
--- a/examples/langchain-python-simple/README.md
+++ b/examples/langchain-python-simple/README.md
@@ -4,10 +4,10 @@ This example is a basic "hello world" of using LangChain with Ollama.
 
 ## Running the Example
 
-1. Ensure you have the `llama3.1` model installed:
+1. Ensure you have the `llama3.2` model installed:
 
    ```bash
-   ollama pull llama3.1
+   ollama pull llama3.2
    ```
 
 2. Install the Python Requirements.
diff --git a/examples/langchain-python-simple/main.py b/examples/langchain-python-simple/main.py
index a7ed81d67..dafff8275 100644
--- a/examples/langchain-python-simple/main.py
+++ b/examples/langchain-python-simple/main.py
@@ -1,6 +1,6 @@
 from langchain.llms import Ollama
 
-input = input("What is your question?")
-llm = Ollama(model="llama3.1")
-res = llm.predict(input)
+input = input("What is your question?\n> ")
+llm = Ollama(model="llama3.2")
+res = llm.invoke(input)
 print (res)
diff --git a/examples/modelfile-mario/Modelfile b/examples/modelfile-mario/Modelfile
index a37470864..b8e49667a 100644
--- a/examples/modelfile-mario/Modelfile
+++ b/examples/modelfile-mario/Modelfile
@@ -1,4 +1,4 @@
-FROM llama3.1
+FROM llama3.2
 PARAMETER temperature 1
 SYSTEM """
 You are Mario from super mario bros, acting as an assistant.
diff --git a/examples/modelfile-mario/readme.md b/examples/modelfile-mario/readme.md
index c3f34197a..882023add 100644
--- a/examples/modelfile-mario/readme.md
+++ b/examples/modelfile-mario/readme.md
@@ -2,12 +2,12 @@
 
 # Example character: Mario
 
-This example shows how to create a basic character using Llama3.1 as the base model.
+This example shows how to create a basic character using Llama 3.2 as the base model.
 
 To run this example:
 
 1. Download the Modelfile
-2. `ollama pull llama3.1` to get the base model used in the model file.
+2. `ollama pull llama3.2` to get the base model used in the model file.
 3. `ollama create NAME -f ./Modelfile`
 4. `ollama run NAME`
 
@@ -18,7 +18,7 @@ Ask it some questions like "Who are you?" or "Is Peach in trouble again?"
 What the model file looks like:
 
 ```
-FROM llama3.1
+FROM llama3.2
 PARAMETER temperature 1
 SYSTEM """
 You are Mario from Super Mario Bros, acting as an assistant.
diff --git a/examples/python-grounded-factuality-rag-check/README.md b/examples/python-grounded-factuality-rag-check/README.md
index 5c9817526..868b16230 100644
--- a/examples/python-grounded-factuality-rag-check/README.md
+++ b/examples/python-grounded-factuality-rag-check/README.md
@@ -1,14 +1,14 @@
 # RAG Hallucination Checker using Bespoke-Minicheck
 
-This example allows the user to ask questions related to a document, which can be specified via an article url. Relevant chunks are retreived from the document and given to `llama3.1` as context to answer the question. Then each sentence in the answer is checked against the retrieved chunks using `bespoke-minicheck` to ensure that the answer does not contain hallucinations. 
+This example allows the user to ask questions related to a document, which can be specified via an article URL. Relevant chunks are retrieved from the document and given to `llama3.2` as context to answer the question. Then each sentence in the answer is checked against the retrieved chunks using `bespoke-minicheck` to ensure that the answer does not contain hallucinations.
 
 ## Running the Example
 
-1. Ensure `all-minilm` (embedding) `llama3.1` (chat) and `bespoke-minicheck` (check) models installed:
+1. Ensure the `all-minilm` (embedding), `llama3.2` (chat), and `bespoke-minicheck` (check) models are installed:
 
    ```bash
    ollama pull all-minilm
-   ollama pull llama3.1
+   ollama pull llama3.2
    ollama pull bespoke-minicheck
    ```
 
diff --git a/examples/python-grounded-factuality-rag-check/main.py b/examples/python-grounded-factuality-rag-check/main.py
index f4d562d5f..dd18f3efa 100644
--- a/examples/python-grounded-factuality-rag-check/main.py
+++ b/examples/python-grounded-factuality-rag-check/main.py
@@ -115,11 +115,11 @@ if __name__ == "__main__":
 
         print(f"\nRetrieved chunks: \n{sourcetext}\n")
 
-        # Give the retreived chunks and question to the chat model
+        # Give the retrieved chunks and question to the chat model
         system_prompt = f"Only use the following information to answer the question. Do not use anything else: {sourcetext}"
 
         ollama_response = ollama.generate(
-            model="llama3.1",
+            model="llama3.2",
             prompt=question,
             system=system_prompt,
             options={"stream": False},
diff --git a/examples/python-json-datagenerator/predefinedschema.py b/examples/python-json-datagenerator/predefinedschema.py
index 68090ad79..914637603 100644
--- a/examples/python-json-datagenerator/predefinedschema.py
+++ b/examples/python-json-datagenerator/predefinedschema.py
@@ -2,7 +2,7 @@ import requests
 import json
 import random
 
-model = "llama3.1"
+model = "llama3.2"
 template = {
   "firstName": "",
   "lastName": "",
diff --git a/examples/python-json-datagenerator/randomaddresses.py b/examples/python-json-datagenerator/randomaddresses.py
index 878c98037..3df59d320 100644
--- a/examples/python-json-datagenerator/randomaddresses.py
+++ b/examples/python-json-datagenerator/randomaddresses.py
@@ -12,7 +12,7 @@ countries = [
     "France",
 ]
 country = random.choice(countries)
-model = "llama3.1"
+model = "llama3.2"
 
 prompt = f"generate one realistically believable sample data set of a persons first name, last name, address in {country}, and phone number. Do not use common names. Respond using JSON. Key names should have no backslashes, values should use plain ascii with no special characters."
 
diff --git a/examples/python-json-datagenerator/readme.md b/examples/python-json-datagenerator/readme.md
index 5b444dff1..a551e1dd5 100644
--- a/examples/python-json-datagenerator/readme.md
+++ b/examples/python-json-datagenerator/readme.md
@@ -6,10 +6,10 @@ There are two python scripts in this example. `randomaddresses.py` generates ran
 
 ## Running the Example
 
-1. Ensure you have the `llama3.1` model installed:
+1. Ensure you have the `llama3.2` model installed:
 
    ```bash
-   ollama pull llama3.1
+   ollama pull llama3.2
    ```
 
 2. Install the Python Requirements.
diff --git a/examples/python-simplechat/client.py b/examples/python-simplechat/client.py
index 85043d5f4..6ef14ffc2 100644
--- a/examples/python-simplechat/client.py
+++ b/examples/python-simplechat/client.py
@@ -2,7 +2,7 @@ import json
 import requests
 
 # NOTE: ollama must be running for this to work, start the ollama app or run `ollama serve`
-model = "llama3.1"  # TODO: update this for whatever model you wish to use
+model = "llama3.2"  # TODO: update this for whatever model you wish to use
 
 
 def chat(messages):
diff --git a/examples/python-simplechat/readme.md b/examples/python-simplechat/readme.md
index 4c2ded4d8..a4a2dfc1e 100644
--- a/examples/python-simplechat/readme.md
+++ b/examples/python-simplechat/readme.md
@@ -4,10 +4,10 @@ The **chat** endpoint is one of two ways to generate text from an LLM with Ollam
 
 ## Running the Example
 
-1. Ensure you have the `llama3.1` model installed:
+1. Ensure you have the `llama3.2` model installed:
 
    ```bash
-   ollama pull llama3.1
+   ollama pull llama3.2
    ```
 
 2. Install the Python Requirements.
diff --git a/examples/typescript-simplechat/client.ts b/examples/typescript-simplechat/client.ts
index 8ad113b12..d8faaa1b5 100644
--- a/examples/typescript-simplechat/client.ts
+++ b/examples/typescript-simplechat/client.ts
@@ -1,6 +1,6 @@
 import * as readline from "readline";
 
-const model = "llama3.1";
+const model = "llama3.2";
 type Message = {
   role: "assistant" | "user" | "system";
   content: string;
diff --git a/go.mod b/go.mod
index 6e437c730..1a1fdb408 100644
--- a/go.mod
+++ b/go.mod
@@ -1,34 +1,35 @@
 module github.com/ollama/ollama
 
-go 1.22.5
+go 1.23.4
 
 require (
 	github.com/containerd/console v1.0.3
-	github.com/emirpasic/gods v1.18.1
 	github.com/gin-gonic/gin v1.10.0
 	github.com/golang/protobuf v1.5.4 // indirect
-	github.com/google/uuid v1.1.2
+	github.com/google/uuid v1.6.0
 	github.com/olekukonko/tablewriter v0.0.5
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.3.0
+	golang.org/x/sync v0.10.0
 )
 
 require (
 	github.com/agnivade/levenshtein v1.1.1
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
+	github.com/emirpasic/gods/v2 v2.0.0-alpha
 	github.com/google/go-cmp v0.6.0
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
+	golang.org/x/image v0.22.0
 )
 
 require (
 	github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/chewxy/hm v1.0.0 // indirect
-	github.com/chewxy/math32 v1.10.1 // indirect
+	github.com/chewxy/math32 v1.11.0 // indirect
 	github.com/cloudwego/base64x v0.1.4 // indirect
 	github.com/cloudwego/iasm v0.2.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
@@ -67,12 +68,12 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.23.0
+	golang.org/x/crypto v0.31.0
 	golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
 	golang.org/x/net v0.25.0 // indirect
-	golang.org/x/sys v0.20.0
-	golang.org/x/term v0.20.0
-	golang.org/x/text v0.15.0
+	golang.org/x/sys v0.28.0
+	golang.org/x/term v0.27.0
+	golang.org/x/text v0.21.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index 926ed26d8..6a2c91896 100644
--- a/go.sum
+++ b/go.sum
@@ -21,8 +21,8 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA
 github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
 github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
 github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
-github.com/chewxy/math32 v1.10.1 h1:LFpeY0SLJXeaiej/eIp2L40VYfscTvKh/FSEZ68uMkU=
-github.com/chewxy/math32 v1.10.1/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
+github.com/chewxy/math32 v1.11.0 h1:8sek2JWqeaKkVnHa7bPVqCEOUPbARo4SGxs6toKyAOo=
+github.com/chewxy/math32 v1.11.0/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
 github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
 github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
 github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
@@ -42,8 +42,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
 github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
-github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
-github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
+github.com/emirpasic/gods/v2 v2.0.0-alpha h1:dwFlh8pBg1VMOXWGipNMRt8v96dKAIvBehtCt6OtunU=
+github.com/emirpasic/gods/v2 v2.0.0-alpha/go.mod h1:W0y4M2dtBB9U5z3YlghmpuUhiaZT2h6yoeE+C1sCp6A=
 github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
 github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@@ -113,8 +113,9 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@@ -211,8 +212,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
-golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
+golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
+golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -230,6 +231,8 @@ golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+o
 golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
+golang.org/x/image v0.22.0 h1:UtK5yLUzilVrkjMAZAZ34DXGpASN8i8pj8g+O+yd10g=
+golang.org/x/image v0.22.0/go.mod h1:9hPFhljd4zZ1GNSIZJ49sqbp45GKK9t6w+iXvGqZUz4=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
@@ -263,8 +266,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
-golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
+golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
+golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -280,17 +283,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
-golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
+golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
-golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
+golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
+golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
-golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
+golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
+golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/gpu/gpu_linux.go b/gpu/gpu_linux.go
deleted file mode 100644
index 1251c6e8e..000000000
--- a/gpu/gpu_linux.go
+++ /dev/null
@@ -1,110 +0,0 @@
-package gpu
-
-import (
-	"bufio"
-	"fmt"
-	"os"
-	"strings"
-
-	"github.com/ollama/ollama/format"
-)
-
-var CudartGlobs = []string{
-	"/usr/local/cuda/lib64/libcudart.so*",
-	"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
-	"/usr/lib/x86_64-linux-gnu/libcudart.so*",
-	"/usr/lib/wsl/lib/libcudart.so*",
-	"/usr/lib/wsl/drivers/*/libcudart.so*",
-	"/opt/cuda/lib64/libcudart.so*",
-	"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
-	"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
-	"/usr/lib/aarch64-linux-gnu/libcudart.so*",
-	"/usr/local/cuda/lib*/libcudart.so*",
-	"/usr/lib*/libcudart.so*",
-	"/usr/local/lib*/libcudart.so*",
-}
-
-var NvmlGlobs = []string{}
-
-var NvcudaGlobs = []string{
-	"/usr/local/cuda*/targets/*/lib/libcuda.so*",
-	"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
-	"/usr/lib/*-linux-gnu/libcuda.so*",
-	"/usr/lib/wsl/lib/libcuda.so*",
-	"/usr/lib/wsl/drivers/*/libcuda.so*",
-	"/opt/cuda/lib*/libcuda.so*",
-	"/usr/local/cuda/lib*/libcuda.so*",
-	"/usr/lib*/libcuda.so*",
-	"/usr/local/lib*/libcuda.so*",
-}
-
-var OneapiGlobs = []string{
-	"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
-	"/usr/lib*/libze_intel_gpu.so*",
-}
-
-var (
-	CudartMgmtName = "libcudart.so*"
-	NvcudaMgmtName = "libcuda.so*"
-	NvmlMgmtName   = "" // not currently wired on linux
-	OneapiMgmtName = "libze_intel_gpu.so*"
-	VulkanMgmtName = "libvulkan.so*"
-	libcapMgmtName = "libcap.so*"
-)
-
-var VulkanGlobs = []string{
-	"/usr/lib/x86_64-linux-gnu/libvulkan.so*",
-	"/usr/lib/aarch64-linux-gnu/libvulkan.so*",
-	"/usr/lib*/libvulkan.so*",
-}
-
-var capLinuxGlobs = []string{
-	"/usr/lib/x86_64-linux-gnu/libvulkan.so*",
-	"/usr/lib/aarch64-linux-gnu/libcap.so*",
-	"/usr/lib*/libcap.so*",
-}
-
-func FindLibCapLibs() []string {
-	return FindGPULibs(libcapMgmtName, capLinuxGlobs)
-}
-
-func GetCPUMem() (memInfo, error) {
-	var mem memInfo
-	var total, available, free, buffers, cached, freeSwap uint64
-	f, err := os.Open("/proc/meminfo")
-	if err != nil {
-		return mem, err
-	}
-	defer f.Close()
-	s := bufio.NewScanner(f)
-	for s.Scan() {
-		line := s.Text()
-		switch {
-		case strings.HasPrefix(line, "MemTotal:"):
-			_, err = fmt.Sscanf(line, "MemTotal:%d", &total)
-		case strings.HasPrefix(line, "MemAvailable:"):
-			_, err = fmt.Sscanf(line, "MemAvailable:%d", &available)
-		case strings.HasPrefix(line, "MemFree:"):
-			_, err = fmt.Sscanf(line, "MemFree:%d", &free)
-		case strings.HasPrefix(line, "Buffers:"):
-			_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
-		case strings.HasPrefix(line, "Cached:"):
-			_, err = fmt.Sscanf(line, "Cached:%d", &cached)
-		case strings.HasPrefix(line, "SwapFree:"):
-			_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
-		default:
-			continue
-		}
-		if err != nil {
-			return mem, err
-		}
-	}
-	mem.TotalMemory = total * format.KibiByte
-	mem.FreeSwap = freeSwap * format.KibiByte
-	if available > 0 {
-		mem.FreeMemory = available * format.KibiByte
-	} else {
-		mem.FreeMemory = (free + buffers + cached) * format.KibiByte
-	}
-	return mem, nil
-}
diff --git a/gpu/gpu_windows.go b/gpu/gpu_windows.go
deleted file mode 100644
index f9afb6ed8..000000000
--- a/gpu/gpu_windows.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package gpu
-
-import (
-	"fmt"
-	"syscall"
-	"unsafe"
-)
-
-type MEMORYSTATUSEX struct {
-	length               uint32
-	MemoryLoad           uint32
-	TotalPhys            uint64
-	AvailPhys            uint64
-	TotalPageFile        uint64
-	AvailPageFile        uint64
-	TotalVirtual         uint64
-	AvailVirtual         uint64
-	AvailExtendedVirtual uint64
-}
-
-var (
-	k32                      = syscall.NewLazyDLL("kernel32.dll")
-	globalMemoryStatusExProc = k32.NewProc("GlobalMemoryStatusEx")
-	sizeofMemoryStatusEx     = uint32(unsafe.Sizeof(MEMORYSTATUSEX{}))
-)
-
-var CudartGlobs = []string{
-	"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
-}
-
-var NvmlGlobs = []string{
-	"c:\\Windows\\System32\\nvml.dll",
-}
-
-var NvcudaGlobs = []string{
-	"c:\\windows\\system*\\nvcuda.dll",
-}
-
-var OneapiGlobs = []string{
-	"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
-}
-
-var VulkanGlobs = []string{
-	"c:\\Windows\\System32\\vulkan-1.dll",
-}
-
-var (
-	CudartMgmtName = "cudart64_*.dll"
-	NvcudaMgmtName = "nvcuda.dll"
-	NvmlMgmtName   = "nvml.dll"
-	OneapiMgmtName = "ze_intel_gpu64.dll"
-	VulkanMgmtName = "vulkan-1.dll"
-)
-
-func FindLibCapLibs() []string {
-	return []string{""}
-}
-
-func GetCPUMem() (memInfo, error) {
-	memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
-	r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus)))
-	if r1 == 0 {
-		return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
-	}
-	return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil
-}
diff --git a/integration/basic_test.go b/integration/basic_test.go
index 8e35b5c5b..88d3530e6 100644
--- a/integration/basic_test.go
+++ b/integration/basic_test.go
@@ -30,6 +30,48 @@ func TestOrcaMiniBlueSky(t *testing.T) {
 	GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
 }
 
+func TestUnicode(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	// Set up the test data
+	req := api.GenerateRequest{
+		// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
+		Model:  "deepseek-coder-v2:16b-lite-instruct-q2_K",
+		Prompt: "天空为什么是蓝色的?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+			// Workaround deepseek context shifting bug
+			"num_ctx":     8192,
+			"num_predict": 2048,
+		},
+	}
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
+}
+
+func TestExtendedUnicodeOutput(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+	defer cancel()
+	// Set up the test data
+	req := api.GenerateRequest{
+		Model:  "gemma2:2b",
+		Prompt: "Output some smily face emoji",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+		},
+	}
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
+}
+
 func TestUnicodeModelDir(t *testing.T) {
 	// This is only useful for Windows with utf-16 characters, so skip this test for other platforms
 	if runtime.GOOS != "windows" {
diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go
index 42e9d0749..78e3b5ab6 100644
--- a/integration/concurrency_test.go
+++ b/integration/concurrency_test.go
@@ -42,7 +42,7 @@ func TestMultiModelConcurrency(t *testing.T) {
 		}
 		resp = [2][]string{
 			{"sunlight"},
-			{"england", "english", "massachusetts", "pilgrims", "british"},
+			{"england", "english", "massachusetts", "pilgrims", "british", "festival"},
 		}
 	)
 	var wg sync.WaitGroup
@@ -60,7 +60,8 @@ func TestMultiModelConcurrency(t *testing.T) {
 	for i := 0; i < len(req); i++ {
 		go func(i int) {
 			defer wg.Done()
-			DoGenerate(ctx, t, client, req[i], resp[i], 60*time.Second, 10*time.Second)
+			// Note: CPU-based inference can crawl, so don't give up too quickly
+			DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 30*time.Second)
 		}(i)
 	}
 	wg.Wait()
@@ -206,7 +207,7 @@ func TestMultiModelStress(t *testing.T) {
 		chosenModels = mediumModels
 		// default:
 		// 	slog.Info("selecting large models")
-		// 	chosenModels = largModels
+		// 	chosenModels = largeModels
 	}
 
 	req, resp := GenerateRequests()
@@ -231,7 +232,7 @@ func TestMultiModelStress(t *testing.T) {
 	var wg sync.WaitGroup
 	consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
 	for i := 0; i < len(req); i++ {
-		// Always get at least 2 models, but dont' overshoot VRAM too much or we'll take too long
+		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
 		if i > 1 && consumed > maxVram {
 			slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
 			break
diff --git a/integration/context_test.go b/integration/context_test.go
index f1342e16c..add41a762 100644
--- a/integration/context_test.go
+++ b/integration/context_test.go
@@ -10,7 +10,38 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
+func TestLongInputContext(t *testing.T) {
+	// Setting NUM_PARALLEL to 1 ensures the allocated context is exactly what
+	// we asked for and there is nothing extra that we could spill over into
+	t.Setenv("OLLAMA_NUM_PARALLEL", "1")
+
+	// Longer needed for small footprint GPUs
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	// Set up the test data
+	req := api.GenerateRequest{
+		Model:  "llama2",
+		Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"temperature": 0,
+			"seed":        123,
+			"num_ctx":     128,
+		},
+	}
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	if err := PullIfMissing(ctx, client, req.Model); err != nil {
+		t.Fatalf("PullIfMissing failed: %v", err)
+	}
+	DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
+}
+
 func TestContextExhaustion(t *testing.T) {
+	// Setting NUM_PARALLEL to 1 ensures the allocated context is exactly what
+	// we asked for and there is nothing extra that we could spill over into
+	t.Setenv("OLLAMA_NUM_PARALLEL", "1")
+
 	// Longer needed for small footprint GPUs
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
 	defer cancel()
diff --git a/integration/embed_test.go b/integration/embed_test.go
index 4a68af68a..8a95816a5 100644
--- a/integration/embed_test.go
+++ b/integration/embed_test.go
@@ -11,12 +11,24 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
-func floatsEqual32(a, b float32) bool {
-	return math.Abs(float64(a-b)) <= 1e-4
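+// dotProduct returns the sum of the element-wise products of v1 and v2 (the vectors are assumed to have equal length).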
+func dotProduct[V float32 | float64](v1, v2 []V) V {
+	var result V = 0
+	for i := 0; i < len(v1); i++ {
+		result += v1[i] * v2[i]
+	}
+	return result
 }
 
-func floatsEqual64(a, b float64) bool {
-	return math.Abs(a-b) <= 1e-4
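+// magnitude returns the Euclidean (L2) norm of v.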
+func magnitude[V float32 | float64](v []V) V {
+	var result V = 0
+	for _, val := range v {
+		result += val * val
+	}
+	return V(math.Sqrt(float64(result)))
+}
+
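+// cosineSimilarity returns a value near 1 when v1 and v2 point in nearly the same direction,
+// letting these tests tolerate small per-element numeric drift instead of requiring exact float matches.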
+func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
+	return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
 }
 
 func TestAllMiniLMEmbeddings(t *testing.T) {
@@ -38,8 +50,12 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
 	}
 
-	if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
-		t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
+	expected := []float64{
+		0.06642947345972061, -0.01160573959350586, 0.3302811086177826, 0.309552937746048, 0.36223655939102173, 0.05672447010874748, 0.6955016851425171, -0.17069467902183533, 0.8547305464744568, 0.21076075732707977, -0.29339903593063354, -0.05926772207021713, -0.003363408148288727, -0.4204462468624115, -0.1061280220746994, 0.30754348635673523, -0.14551642537117004, -1.0430994033813477, -0.4805174171924591, -0.40448474884033203, -0.4345352053642273, 0.3573606014251709, -0.4098161458969116, 0.25664326548576355, -0.3021087646484375, 0.36236199736595154, -0.23262615501880646, 0.08319848775863647, 0.28042519092559814, -0.052289899438619614, -0.12552005052566528, 0.402255117893219, 0.24357250332832336, 0.08881516754627228, -0.17023836076259613, -0.2868475615978241, 0.4790303707122803, -0.3199635446071625, 0.02826809138059616, -0.19417747855186462, -0.19217649102210999, -0.21705707907676697, -0.1210065633058548, 0.10262420773506165, -0.07726037502288818, 0.10094445943832397, -0.06194962561130524, 0.1712605208158493, 0.628441333770752, -0.10222385078668594, -0.16214007139205933, 0.059920795261859894, -0.5053377151489258, 0.10545563697814941, 0.32686805725097656, 0.7650210857391357, 0.006465774029493332, -0.13403119146823883, 0.6090353727340698, 0.05603303387761116, -0.37635889649391174, 0.45424884557724, -0.5053073763847351, 0.4572359323501587, 0.6084011197090149, -0.3659921884536743, -0.3536888360977173, 0.05569244921207428, -0.4166066646575928, -0.43796032667160034, -0.16600576043128967, 0.12460685521364212, 0.40493422746658325, -0.18632565438747406, 0.2390710711479187, 0.007283639162778854, 0.4001992344856262, -0.4455743134021759, -0.05360018089413643, -0.08401738107204437, 0.2041706144809723, -0.42083415389060974, -0.491476833820343, 0.7860275506973267, 0.08280622214078903, 0.4309011697769165, 0.09778489172458649, 0.3392091989517212, -0.5618907809257507, 0.06766007840633392, -0.05127308890223503, -0.23472431302070618, -0.7611223459243774, -0.20227840542793274, -0.5491426587104797, 0.09030043333768845, 0.37326449155807495, -0.2696656584739685, 0.2814738154411316, 0.1461343765258789, 0.309052437543869, -0.3387487828731537, 0.1990429162979126, 0.0474909171462059, -0.02756538614630699, -0.20544570684432983, 0.5137258768081665, 0.22562497854232788, 0.40487033128738403, 0.04954294115304947, -0.23911823332309723, -0.5578761696815491, 0.14376327395439148, -0.12795016169548035, -0.26285219192504883, 0.3614377975463867, -0.22225692868232727, 0.11940789222717285, -0.6961514353752136, -0.3324243426322937, -0.07613810151815414, 0.24946099519729614, 0.1462409496307373, 0.5309336185455322, 0.051560595631599426, -0.11104149371385574, -0.39189594984054565, -4.767201176712463e-32, 0.892546534538269, -0.07396792620420456, 0.6088366508483887, 0.23729179799556732, 0.2614588737487793, -0.3626874089241028, -0.23131835460662842, -0.024579279124736786, -0.12901946902275085, -0.2306443750858307, -0.0376533679664135, -0.09649471938610077, -0.16013199090957642, -0.31914401054382324, 0.3151017129421234, -0.11264121532440186, -0.4020160734653473, 0.039211247116327286, -0.5478582978248596, 0.5563258528709412, -0.6903842091560364, 0.2746567130088806, -0.24196553230285645, -0.053318753838539124, -0.18611761927604675, -0.28490889072418213, 0.237456813454628, 0.4946249723434448, 0.37237465381622314, 0.07815749943256378, 0.6494859457015991, 0.6915512084960938, -0.14422327280044556, 0.30338582396507263, -0.17378094792366028, -0.33589833974838257, -0.09702004492282867, -0.04210608825087547, -0.566387414932251, 0.18866634368896484, 
-0.3533778488636017, 0.37286972999572754, -0.39420801401138306, 0.0818595215678215, 0.436712384223938, -0.08886678516864777, 0.2527940273284912, -0.5864061117172241, -0.37891554832458496, 0.21103361248970032, -0.2275354266166687, 0.1558678150177002, 0.09536703675985336, -0.27437490224838257, 0.4484926164150238, 0.20584626495838165, 0.45972558856010437, -0.231113001704216, -0.021833699196577072, 0.3253912925720215, -0.08802174031734467, -0.023067735135555267, 0.33492740988731384, 0.5189340114593506, 0.2481488585472107, -0.07638847082853317, 0.25147074460983276, 0.2771286964416504, -0.08443005383014679, -0.5207436084747314, 0.05951530486345291, 0.08816319704055786, 0.15935833752155304, 0.0644921213388443, -0.07194079458713531, -0.5383226871490479, 0.17800968885421753, -0.195652037858963, -0.028597159311175346, 0.08582349121570587, -0.23225288093090057, -0.12984338402748108, 0.3651025593280792, -0.4039592146873474, -0.3628298342227936, 0.08263863623142242, -0.12648534774780273, -0.08284908533096313, -0.1042669266462326, -0.4579034447669983, -0.2961195111274719, -0.32282471656799316, 0.3182551860809326, -0.6890494227409363, -0.7114676237106323, 2.3665072841905432e-32, -0.0030965525656938553, -0.5696439146995544, -0.5794872045516968, 0.04729880392551422, -0.048917483538389206, -0.10963250696659088, 0.298623263835907, 0.4452674388885498, -0.2828809320926666, 0.5696343183517456, 0.3004711866378784, 0.44842660427093506, 0.06550214439630508, -0.020054858177900314, 0.385932058095932, -0.23460465669631958, 0.23865005373954773, 0.4363722801208496, -0.24931970238685608, -0.41073542833328247, -0.2937365770339966, 0.5095447301864624, 0.2864843010902405, -0.14028388261795044, -0.14269764721393585, 0.4107881486415863, -0.2581801116466522, 0.18544888496398926, -0.08612997084856033, 0.33715111017227173, -0.24288496375083923, 0.3599962592124939, -0.43829354643821716, 0.15094976127147675, 0.03177203983068466, 0.5965112447738647, 0.03364168107509613, -0.5481097102165222, -0.363423228263855, 0.4825053811073303, -0.7288467288017273, -0.13361915946006775, 0.7423286437988281, -0.3515661358833313, -0.37989044189453125, -0.1576842963695526, 0.3734908998012543, 0.8393698930740356, 0.23719121515750885, -0.28990280628204346, 0.11215505003929138, -0.16382968425750732, 0.47951722145080566, 0.28471529483795166, 0.5308315753936768, -0.1286555975675583, -0.22689077258110046, 0.6377706527709961, 0.34224453568458557, 0.07091143727302551, 0.26538553833961487, 0.014475930482149124, -0.050034329295158386, 0.011025313287973404, 0.09357182681560516, 0.1345357596874237, -0.1523902863264084, 0.14176052808761597, -0.0609259307384491, -0.3332745134830475, -0.1072426363825798, -0.5933747291564941, -0.40028926730155945, 0.5343422293663025, 0.016202416270971298, 0.27436596155166626, 0.28844428062438965, -0.1660136878490448, -0.6286065578460693, 0.5850632190704346, -0.6491153836250305, -0.03207448124885559, 0.23312292993068695, 0.09339666366577148, -0.42595869302749634, -0.5011518001556396, 0.08187201619148254, -0.3312609791755676, -0.3677852153778076, -0.3758619427680969, -0.12195874005556107, -0.014479270204901695, -0.014539752155542374, 0.23270025849342346, -0.3609132170677185, -9.438503667524856e-8, -0.05230816453695297, 0.17612962424755096, 0.01489749364554882, 0.06601762771606445, -0.14300350844860077, -0.1422577053308487, 0.7347333431243896, 0.030603498220443726, 0.24959787726402283, 0.026135217398405075, -0.4412609338760376, -0.18663707375526428, -0.29235413670539856, 0.4696626365184784, 0.12353914976119995, -0.3236965537071228, 
-0.6856554746627808, -0.28768694400787354, 0.0671629011631012, 0.27566438913345337, -0.0893339067697525, -0.22328855097293854, -0.16536207497119904, -0.08968719840049744, 0.022607458755373955, 0.21818216145038605, -0.14408129453659058, 0.14458191394805908, 0.4712568521499634, 0.13527995347976685, 0.16118602454662323, 0.23675017058849335, -0.0062652211636304855, -0.4045848250389099, -0.5631943345069885, 0.04897312819957733, -0.2558498978614807, 0.5269845128059387, -0.16870160400867462, -0.39874112606048584, 0.3996037244796753, 0.5432316660881042, -0.3740345239639282, 0.031965695321559906, 0.29769593477249146, 0.1568443477153778, 0.287019282579422, 0.6005253791809082, -0.33905476331710815, -0.07407552748918533, -0.4541633129119873, 0.047827333211898804, 0.4803982973098755, -0.2860602140426636, 0.17097190022468567, -0.7525586485862732, -0.06290972977876663, 0.14645379781723022, 0.176426962018013, 0.024587953463196754, 0.105128213763237, 0.023733407258987427, -0.1363760083913803, 0.22127331793308258,
+	}
+	sim := cosineSimilarity(res.Embedding, expected)
+	if sim < 0.99 {
+		t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embedding[0:5], sim)
 	}
 }
 
@@ -66,8 +82,12 @@ func TestAllMiniLMEmbed(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}
 
-	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
-		t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
+	expected := []float32{
+		0.010071031, -0.0017594865, 0.050072223, 0.046929732, 0.05491682, 0.008599705, 0.105441436, -0.025878143, 0.1295813, 0.031952355, -0.04448072, -0.0089852745, -0.000509909, -0.06374169, -0.016089523, 0.04662509, -0.022060998, -0.15813895, -0.072848774, -0.061321855, -0.065877646, 0.054177605, -0.06213012, 0.038908366, -0.04580116, 0.05493584, -0.035267256, 0.012613296, 0.04251382, -0.007927403, -0.01902945, 0.060983833, 0.036926776, 0.013464811, -0.025808964, -0.043487485, 0.072623335, -0.04850803, 0.00428558, -0.02943825, -0.02913489, -0.03290691, -0.018345183, 0.0155583285, -0.011713048, 0.01530367, -0.009391865, 0.025963927, 0.09527476, -0.015497632, -0.024581224, 0.009084283, -0.07661165, 0.015987588, 0.049554788, 0.115980916, 0.0009802427, -0.02031978, 0.09233272, 0.00849488, -0.05705784, 0.068866335, -0.076607056, 0.06931919, 0.09223656, -0.055486195, -0.053620946, 0.008443246, -0.06315959, -0.066396914, -0.02516728, 0.018891005, 0.061389998, -0.028247874, 0.036244337, 0.0011042351, 0.06067215, -0.06755123, -0.008126048, -0.012737444, 0.030953258, -0.06380051, -0.07451028, 0.1191656, 0.012553826, 0.06532671, 0.014824665, 0.051425762, -0.08518537, 0.010257597, -0.0077732494, -0.035585348, -0.115389846, -0.03066639, -0.0832527, 0.013689985, 0.056588713, -0.040882625, 0.042672798, 0.022154681, 0.04685385, -0.05135596, 0.030175874, 0.007199854, -0.0041790465, -0.031146567, 0.07788334, 0.034205843, 0.06138031, 0.007510951, -0.036251485, -0.08457674, 0.021795211, -0.019397866, -0.03984967, 0.054795727, -0.033695232, 0.018102817, -0.10553994, -0.050397146, -0.011542906, 0.0378195, 0.022170838, 0.08049212, 0.007816837, -0.01683443, -0.059413332, -7.227309e-33, 0.13531439, -0.011213897, 0.0923026, 0.03597459, 0.039638437, -0.054985173, -0.03506899, -0.0037263383, -0.01955998, -0.034966808, -0.0057084337, -0.014629069, -0.024276787, -0.048383784, 0.04777095, -0.017076956, -0.06094759, 0.0059446157, -0.083057985, 0.084341705, -0.1046656, 0.041639294, -0.03668315, -0.008083383, -0.028216336, -0.04319357, 0.035999607, 0.07498755, 0.05645381, 0.011849057, 0.09846523, 0.10484252, -0.021864949, 0.045994766, -0.026346037, -0.05092382, -0.014708711, -0.0063834875, -0.085867085, 0.028602734, -0.0535738, 0.056528863, -0.059763853, 0.012410302, 0.06620772, -0.013472636, 0.038324803, -0.08890202, -0.05744544, 0.03199372, -0.034495477, 0.02363032, 0.014458106, -0.04159657, 0.06799366, 0.031207295, 0.069696635, -0.035037853, -0.0033100948, 0.0493309, -0.0133445235, -0.0034971808, 0.050776623, 0.078672916, 0.037620574, -0.011580864, 0.03812419, 0.04201406, -0.012800006, -0.07894726, 0.00902281, 0.013365969, 0.024159499, 0.009777319, -0.010906574, -0.08161233, 0.026987134, -0.0296618, -0.004335468, 0.013011258, -0.035210665, -0.019684888, 0.055351324, -0.06124218, -0.055006765, 0.012528419, -0.019175794, -0.012560324, -0.015807373, -0.06942039, -0.044893157, -0.048941795, 0.048249032, -0.10446324, -0.10786195, 3.58774e-33, -0.0004694524, -0.08636079, -0.087853074, 0.0071707284, -0.007416128, -0.01662082, 0.045272738, 0.06750471, -0.042886123, 0.08635933, 0.04555289, 0.06798365, 0.009930444, -0.003040414, 0.058509175, -0.035567205, 0.036180507, 0.06615616, -0.03779808, -0.062269486, -0.044531893, 0.07724946, 0.04343241, -0.021267718, -0.021633657, 0.06227748, -0.03914136, 0.028114952, -0.013057723, 0.051113747, -0.036822543, 0.054577183, -0.06644743, 0.022884717, 0.0048167957, 0.09043401, 0.0051002423, -0.083096094, -0.055096727, 0.07315016, -0.11049671, -0.020257315, 0.11254063, -0.053299136, -0.057593238, 
-0.023905706, 0.056623034, 0.12725255, 0.03595934, -0.043950673, 0.017003251, -0.024837377, 0.07269714, 0.043164223, 0.08047665, -0.019504813, -0.034397744, 0.096689135, 0.051885936, 0.010750518, 0.04023374, 0.0021946214, -0.0075854477, 0.0016714911, 0.014185944, 0.020396275, -0.023103109, 0.021491585, -0.009236667, -0.050526038, -0.016258504, -0.0899585, -0.0606858, 0.08100888, 0.0024563652, 0.041595213, 0.043729555, -0.025168482, -0.09529981, 0.088698424, -0.09840905, -0.0048626475, 0.03534257, 0.014159388, -0.06457741, -0.07597705, 0.012412196, -0.050220776, -0.055758025, -0.0569825, -0.018489538, -0.0021951278, -0.002204297, 0.03527849, -0.0547162, -1.430923e-8, -0.007930172, 0.026702108, 0.0022585324, 0.010008593, -0.021680027, -0.02156696, 0.111389145, 0.004639639, 0.03784025, 0.003962226, -0.0668973, -0.028295087, -0.04432231, 0.07120314, 0.018729135, -0.04907397, -0.103948705, -0.043614738, 0.010182222, 0.04179206, -0.013543455, -0.03385163, -0.025069695, -0.013597015, 0.0034274007, 0.033077475, -0.021843424, 0.021919321, 0.07144483, 0.020509098, 0.024436586, 0.035892475, -0.00094983797, -0.061337028, -0.085383, 0.007424564, -0.038788088, 0.07989341, -0.025575982, -0.060451094, 0.060581867, 0.082356565, -0.056705453, 0.0048461547, 0.04513215, 0.023778366, 0.043513518, 0.09104256, -0.05140235, -0.01123021, -0.06885336, 0.007250856, 0.072830714, -0.04336812, 0.025920171, -0.11409155, -0.009537421, 0.022203108, 0.026747186, 0.0037276533, 0.015937949, 0.0035980998, -0.020675266, 0.03354611,
+	}
+	sim := cosineSimilarity(res.Embeddings[0], expected)
+	if sim < 0.99 {
+		t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
 	}
 
 	if res.PromptEvalCount != 6 {
@@ -98,8 +118,22 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
 		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
 	}
 
-	if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
-		t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
+	expected := [][]float32{
+		{
+			0.010071031, -0.0017594865, 0.050072223, 0.046929732, 0.05491682, 0.008599705, 0.105441436, -0.025878143, 0.1295813, 0.031952355, -0.04448072, -0.0089852745, -0.000509909, -0.06374169, -0.016089523, 0.04662509, -0.022060998, -0.15813895, -0.072848774, -0.061321855, -0.065877646, 0.054177605, -0.06213012, 0.038908366, -0.04580116, 0.05493584, -0.035267256, 0.012613296, 0.04251382, -0.007927403, -0.01902945, 0.060983833, 0.036926776, 0.013464811, -0.025808964, -0.043487485, 0.072623335, -0.04850803, 0.00428558, -0.02943825, -0.02913489, -0.03290691, -0.018345183, 0.0155583285, -0.011713048, 0.01530367, -0.009391865, 0.025963927, 0.09527476, -0.015497632, -0.024581224, 0.009084283, -0.07661165, 0.015987588, 0.049554788, 0.115980916, 0.0009802427, -0.02031978, 0.09233272, 0.00849488, -0.05705784, 0.068866335, -0.076607056, 0.06931919, 0.09223656, -0.055486195, -0.053620946, 0.008443246, -0.06315959, -0.066396914, -0.02516728, 0.018891005, 0.061389998, -0.028247874, 0.036244337, 0.0011042351, 0.06067215, -0.06755123, -0.008126048, -0.012737444, 0.030953258, -0.06380051, -0.07451028, 0.1191656, 0.012553826, 0.06532671, 0.014824665, 0.051425762, -0.08518537, 0.010257597, -0.0077732494, -0.035585348, -0.115389846, -0.03066639, -0.0832527, 0.013689985, 0.056588713, -0.040882625, 0.042672798, 0.022154681, 0.04685385, -0.05135596, 0.030175874, 0.007199854, -0.0041790465, -0.031146567, 0.07788334, 0.034205843, 0.06138031, 0.007510951, -0.036251485, -0.08457674, 0.021795211, -0.019397866, -0.03984967, 0.054795727, -0.033695232, 0.018102817, -0.10553994, -0.050397146, -0.011542906, 0.0378195, 0.022170838, 0.08049212, 0.007816837, -0.01683443, -0.059413332, -7.227309e-33, 0.13531439, -0.011213897, 0.0923026, 0.03597459, 0.039638437, -0.054985173, -0.03506899, -0.0037263383, -0.01955998, -0.034966808, -0.0057084337, -0.014629069, -0.024276787, -0.048383784, 0.04777095, -0.017076956, -0.06094759, 0.0059446157, -0.083057985, 0.084341705, -0.1046656, 0.041639294, -0.03668315, -0.008083383, -0.028216336, -0.04319357, 0.035999607, 0.07498755, 0.05645381, 0.011849057, 0.09846523, 0.10484252, -0.021864949, 0.045994766, -0.026346037, -0.05092382, -0.014708711, -0.0063834875, -0.085867085, 0.028602734, -0.0535738, 0.056528863, -0.059763853, 0.012410302, 0.06620772, -0.013472636, 0.038324803, -0.08890202, -0.05744544, 0.03199372, -0.034495477, 0.02363032, 0.014458106, -0.04159657, 0.06799366, 0.031207295, 0.069696635, -0.035037853, -0.0033100948, 0.0493309, -0.0133445235, -0.0034971808, 0.050776623, 0.078672916, 0.037620574, -0.011580864, 0.03812419, 0.04201406, -0.012800006, -0.07894726, 0.00902281, 0.013365969, 0.024159499, 0.009777319, -0.010906574, -0.08161233, 0.026987134, -0.0296618, -0.004335468, 0.013011258, -0.035210665, -0.019684888, 0.055351324, -0.06124218, -0.055006765, 0.012528419, -0.019175794, -0.012560324, -0.015807373, -0.06942039, -0.044893157, -0.048941795, 0.048249032, -0.10446324, -0.10786195, 3.58774e-33, -0.0004694524, -0.08636079, -0.087853074, 0.0071707284, -0.007416128, -0.01662082, 0.045272738, 0.06750471, -0.042886123, 0.08635933, 0.04555289, 0.06798365, 0.009930444, -0.003040414, 0.058509175, -0.035567205, 0.036180507, 0.06615616, -0.03779808, -0.062269486, -0.044531893, 0.07724946, 0.04343241, -0.021267718, -0.021633657, 0.06227748, -0.03914136, 0.028114952, -0.013057723, 0.051113747, -0.036822543, 0.054577183, -0.06644743, 0.022884717, 0.0048167957, 0.09043401, 0.0051002423, -0.083096094, -0.055096727, 0.07315016, -0.11049671, -0.020257315, 0.11254063, -0.053299136, -0.057593238, 
-0.023905706, 0.056623034, 0.12725255, 0.03595934, -0.043950673, 0.017003251, -0.024837377, 0.07269714, 0.043164223, 0.08047665, -0.019504813, -0.034397744, 0.096689135, 0.051885936, 0.010750518, 0.04023374, 0.0021946214, -0.0075854477, 0.0016714911, 0.014185944, 0.020396275, -0.023103109, 0.021491585, -0.009236667, -0.050526038, -0.016258504, -0.0899585, -0.0606858, 0.08100888, 0.0024563652, 0.041595213, 0.043729555, -0.025168482, -0.09529981, 0.088698424, -0.09840905, -0.0048626475, 0.03534257, 0.014159388, -0.06457741, -0.07597705, 0.012412196, -0.050220776, -0.055758025, -0.0569825, -0.018489538, -0.0021951278, -0.002204297, 0.03527849, -0.0547162, -1.430923e-8, -0.007930172, 0.026702108, 0.0022585324, 0.010008593, -0.021680027, -0.02156696, 0.111389145, 0.004639639, 0.03784025, 0.003962226, -0.0668973, -0.028295087, -0.04432231, 0.07120314, 0.018729135, -0.04907397, -0.103948705, -0.043614738, 0.010182222, 0.04179206, -0.013543455, -0.03385163, -0.025069695, -0.013597015, 0.0034274007, 0.033077475, -0.021843424, 0.021919321, 0.07144483, 0.020509098, 0.024436586, 0.035892475, -0.00094983797, -0.061337028, -0.085383, 0.007424564, -0.038788088, 0.07989341, -0.025575982, -0.060451094, 0.060581867, 0.082356565, -0.056705453, 0.0048461547, 0.04513215, 0.023778366, 0.043513518, 0.09104256, -0.05140235, -0.01123021, -0.06885336, 0.007250856, 0.072830714, -0.04336812, 0.025920171, -0.11409155, -0.009537421, 0.022203108, 0.026747186, 0.0037276533, 0.015937949, 0.0035980998, -0.020675266, 0.03354611,
+		},
+		{
+			-0.009802706, 0.060424678, 0.025257956, -0.0063643856, 0.07272723, 0.01719488, 0.090320334, -0.051705167, 0.099515095, 0.09072479, 0.007301506, -0.01968127, -0.075095184, -0.017409375, 0.019365614, 0.040805466, -0.011079843, -0.05856395, -0.12545314, -0.048980292, -0.044052314, 0.03115607, 0.037880868, -0.03187379, -0.0909825, 0.06357952, -0.076541565, 0.085011445, 0.03554875, -0.071272224, 0.021114277, 0.11005397, 0.03312636, -0.025947863, -0.061563145, -0.026466936, 0.02054478, -0.05426622, 0.056569945, 0.03292456, -0.09005933, -0.05698778, 0.026827272, 0.0751872, -0.07142025, -0.0043633, 0.054151993, 0.026441583, 0.078053534, -0.048995998, 0.056577347, -0.048973206, -0.07581186, 0.006902122, 0.0062451144, 0.037024222, 0.025028007, 0.021724675, 0.010117283, -0.040492155, -0.012010403, -0.03334674, -0.07570402, 0.071321115, -0.02062346, -0.0631419, -0.001237942, -0.055173304, 0.009124682, -0.08703634, 0.020684991, 0.05294139, -0.009563882, -0.052647192, -0.06467313, 0.041968923, 0.04473555, 0.03270584, -0.019611169, 0.00013324046, 0.038228948, 0.0509972, 0.0047100335, 0.05736671, 0.046469305, 0.04269017, -0.017305125, 0.011859765, -0.05701112, -0.03498464, -0.018940303, -0.0074608736, -0.07385685, 0.043892473, -0.09890047, 0.041379265, -0.024019944, -0.12034819, 0.0001821356, -0.0038607453, 0.056144036, -0.0005059898, 0.07110965, -0.03616245, -0.06406574, -0.009435536, -0.042290587, 0.07791005, -0.02365763, 0.007864432, -0.023739463, -0.018536761, -0.033538047, 0.0776669, -0.06058719, 0.05363198, 0.033863083, 0.012545284, -0.03260245, 0.029770961, -0.016934512, 0.028213669, -0.018053731, 0.06651968, -0.06952628, -0.017853932, -0.037421644, -6.839719e-33, -0.0055490523, -0.031681225, 0.04819487, -0.09944883, 0.09372583, -0.051811725, -0.037059266, -0.026262678, -0.037466466, -0.030253021, 0.0060922937, -0.09831781, -0.017570594, -0.07247917, 0.03856134, 0.00888377, -0.13072893, 0.02145255, -0.075681135, -0.010470858, -0.017236665, 0.058358245, 0.022016024, 0.0015762328, 0.009419801, -0.031423207, 0.08002972, 0.030580623, 0.05696977, -0.012164853, 0.11575935, 0.0040441174, 0.01759827, 0.043209996, 0.02948431, -0.0069428794, -0.025078153, -0.026160793, 0.013364178, 0.121543564, -0.004469769, -0.04534167, 0.043418996, -0.01768049, 0.062162045, -0.039375506, 0.017406953, 0.008458191, -0.02603069, 0.010130821, 0.023227274, 0.05305319, 0.06899141, 0.053088874, -0.0003113895, 0.009642751, 0.08884011, -0.030399954, -0.090916164, -0.051467095, -0.07382789, 0.08624027, 0.003223033, 0.010827092, -0.008318035, -0.011421701, -0.02900046, 0.06548931, 0.005405483, 0.068780296, 0.0428464, -0.01878741, -0.016996592, -0.036818627, -0.0062817424, -0.08700542, -0.008640271, -0.013171244, -0.004574588, 0.04233393, -0.03579696, 0.017357353, -0.087162524, -0.050884914, -0.14957926, -0.002008126, -0.02634847, 0.018098367, 0.02162604, -0.01503002, 0.0037868456, -0.015445877, -0.013303974, -0.09810386, -0.011673153, 2.8261164e-33, -0.022961555, 0.0090464745, -0.0057421196, 0.06604244, 0.042683356, -0.039691485, 0.027226122, 0.03183442, -0.028517157, 0.045575514, -0.055865873, 0.0924774, -0.046869125, 0.08027759, 0.118624836, 0.04889292, -0.06734586, 0.10688813, 0.009396721, -0.051344905, -0.067946814, 0.01592692, -0.010147019, 0.044173665, -0.030018767, 0.022772646, -0.031494025, -0.02233876, -0.0023573847, -0.010024354, 0.0032828946, -0.036839407, -0.11200184, 0.028629173, 0.030212566, 0.03185506, -0.01746865, -0.018295743, -0.036361173, 0.083925165, 0.007943152, -0.023664381, 0.15850149, 0.032088134, 
-0.070371404, -0.034124147, -0.015502377, 0.07960292, -0.06218589, 0.046537183, 0.04505064, 0.1043822, 0.029607052, 0.047920443, 0.09711685, -0.015767856, -0.064267434, 0.01960162, -0.093837254, -0.0028061024, 0.019721054, -0.027095793, -0.078636706, 0.0689579, 0.107794516, -0.033122607, -0.064406104, 0.016571952, 0.019280795, -0.023045482, -0.018821374, -0.018646069, -0.06431513, -0.03231013, -0.0027636476, 0.059007723, 0.059882853, -0.044795096, -0.06667144, 0.043793377, -0.019855661, -0.006715758, 0.04733659, -0.046866804, 0.03461545, -0.015199261, -0.039511763, 0.047361404, 0.052113988, 0.0008203065, 0.05290727, 0.02459614, -0.029357709, 0.034541644, 0.013009169, -1.36748e-8, -0.033930536, 0.007378359, -0.010701883, 0.04323486, 0.014735074, -0.04162692, 0.10553509, -0.012822099, -0.002357336, 0.040418625, -0.08136588, 0.033679843, -0.019665385, 0.077529214, 0.060347307, -0.016181026, -0.11332622, -0.04306442, 0.023209568, 0.07448782, -0.06055759, -0.045812756, -0.087526724, 0.0534105, -0.044014834, 0.029827949, 0.038628686, 0.016933717, 0.027725562, 0.078133695, 0.055581007, 0.05306717, -0.010792625, -0.029803185, -0.08492531, -0.016416015, 0.030501937, 0.06944753, -0.061944496, -0.122021444, 0.011901371, 0.07258673, -0.017778289, 0.0030972173, 0.014411535, -0.03802866, -0.052976213, 0.060414705, -0.053164586, 0.01794129, -0.104411006, 0.010633235, 0.042881854, 0.042603284, -0.003009017, -0.08530093, -0.039561126, -0.004481811, 0.013104284, -0.008498699, -0.028943708, -0.03587923, 0.05940551, -0.000055299755,
+		},
+	}
+
+	sim := cosineSimilarity(res.Embeddings[0], expected[0])
+	if sim < 0.99 {
+		t.Fatalf("expected %v, got %v (similarity: %f)", expected[0][0:5], res.Embeddings[0][0:5], sim)
+	}
+	sim = cosineSimilarity(res.Embeddings[1], expected[1])
+	if sim < 0.99 {
+		t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
 	}
 
 	if res.PromptEvalCount != 12 {
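Editor's note: the two checks above validate each returned embedding against the expected vector using cosine similarity with a 0.99 threshold. The `cosineSimilarity` helper itself is not part of this diff; as a rough sketch (the real helper in the test suite may use a different signature or generic types), it amounts to:

```go
package integration

import "math"

// cosineSimilarity returns the cosine of the angle between two embedding
// vectors: 1.0 means identical direction, 0.0 means orthogonal.
func cosineSimilarity(a, b []float32) float32 {
	var dot, normA, normB float64
	for i := range a {
		dot += float64(a[i]) * float64(b[i])
		normA += float64(a[i]) * float64(a[i])
		normB += float64(b[i]) * float64(b[i])
	}
	if normA == 0 || normB == 0 {
		return 0
	}
	return float32(dot / (math.Sqrt(normA) * math.Sqrt(normB)))
}
```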
diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go
index d0c861cc0..c7b56890e 100644
--- a/integration/llm_image_test.go
+++ b/integration/llm_image_test.go
@@ -12,7 +12,7 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-func TestIntegrationMultimodal(t *testing.T) {
+func TestIntegrationLlava(t *testing.T) {
 	image, err := base64.StdEncoding.DecodeString(imageEncoding)
 	require.NoError(t, err)
 	req := api.GenerateRequest{
@@ -39,6 +39,33 @@ func TestIntegrationMultimodal(t *testing.T) {
 	DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
 }
 
+func TestIntegrationMllama(t *testing.T) {
+	image, err := base64.StdEncoding.DecodeString(imageEncoding)
+	require.NoError(t, err)
+	req := api.GenerateRequest{
+		// TODO fix up once we publish the final image
+		Model:  "x/llama3.2-vision",
+		Prompt: "what does the text in this image say?",
+		Stream: &stream,
+		Options: map[string]interface{}{
+			"seed":        42,
+			"temperature": 0.0,
+		},
+		Images: []api.ImageData{
+			image,
+		},
+	}
+
+	resp := "the ollamas"
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+	require.NoError(t, PullIfMissing(ctx, client, req.Model))
+	// mllama models on CPU can be quite slow to start,
+	DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
+}
+
 const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
 AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA
 AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6
diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go
index ec9e085a5..1878d0dac 100644
--- a/integration/max_queue_test.go
+++ b/integration/max_queue_test.go
@@ -16,23 +16,18 @@ import (
 	"github.com/stretchr/testify/require"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 )
 
 func TestMaxQueue(t *testing.T) {
 	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
-		t.Skip("Max Queue test requires spawing a local server so we can adjust the queue size")
+		t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
 		return
 	}
 
 	// Note: This test can be quite slow when running in CPU mode, so keep the threadCount low unless your on GPU
 	// Also note that by default Darwin can't sustain > ~128 connections without adjusting limits
-	threadCount := 32
-	if maxQueue := envconfig.MaxQueue(); maxQueue != 0 {
-		threadCount = int(maxQueue)
-	} else {
-		t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
-	}
+	threadCount := 16
+	t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
 
 	req := api.GenerateRequest{
 		Model:  "orca-mini",
@@ -72,7 +67,7 @@ func TestMaxQueue(t *testing.T) {
 	busyCount := 0
 	resetByPeerCount := 0
 	canceledCount := 0
-	succesCount := 0
+	successCount := 0
 	counterMu := sync.Mutex{}
 	var embedwg sync.WaitGroup
 	for i := 0; i < threadCount; i++ {
@@ -93,7 +88,7 @@ func TestMaxQueue(t *testing.T) {
 			defer counterMu.Unlock()
 			switch {
 			case genErr == nil:
-				succesCount++
+				successCount++
 				require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
 			case errors.Is(genErr, context.Canceled):
 				canceledCount++
@@ -112,7 +107,7 @@ func TestMaxQueue(t *testing.T) {
 	slog.Info("generate done, waiting for embeds")
 	embedwg.Wait()
 
-	slog.Info("embeds completed", "success", succesCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
+	slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
 	require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
 	require.True(t, busyCount > 0, "no requests hit busy error but some should have")
 	require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
diff --git a/integration/utils_test.go b/integration/utils_test.go
index a60109958..e76b63f2a 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -275,7 +275,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
 				break
 			}
 		}
-		require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
+		require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
 		slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
 	case <-ctx.Done():
 		t.Error("outer test context done while waiting for generate")
diff --git a/llama/.gitignore b/llama/.gitignore
new file mode 100644
index 000000000..ee5520435
--- /dev/null
+++ b/llama/.gitignore
@@ -0,0 +1,3 @@
+*.bin
+*.gguf
+build/
\ No newline at end of file
diff --git a/llama/README.md b/llama/README.md
new file mode 100644
index 000000000..3b6b20672
--- /dev/null
+++ b/llama/README.md
@@ -0,0 +1,160 @@
+# `llama`
+
+This package integrates the [llama.cpp](https://github.com/ggerganov/llama.cpp) library as a Go package and makes it easy to build it with tags for different CPU and GPU processors.
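+
+As a generic illustration (not the actual file layout of this package), a build-tag-constrained Go source file looks like the sketch below; sibling files carrying different tags provide the per-CPU/GPU variants:
+
+```go
+//go:build avx
+
+package llama
+
+// This file is compiled only when the avx tag is set, e.g. `go build -tags avx .`;
+// a sibling file with a constraint such as `//go:build !avx` supplies the fallback.
+```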
+
+Supported:
+
+- [x] CPU
+- [x] avx, avx2
+- [x] macOS Metal
+- [x] Windows CUDA
+- [x] Windows ROCm
+- [x] Linux CUDA
+- [x] Linux ROCm
+- [x] Llava
+
+Extra build steps are required for CUDA and ROCm on Windows, since `nvcc` and `hipcc` both require MSVC as the host compiler. For these targets, shared libraries are created:
+
+- `ggml_cuda.dll` on Windows or `ggml_cuda.so` on Linux
+- `ggml_hipblas.dll` on Windows or `ggml_hipblas.so` on Linux
+
+> Note: it's important that memory is allocated and freed by the same compiler (e.g. entirely by code compiled with MSVC or MinGW). Issues from this should be rare, but there are some places where pointers are returned by the CUDA or HIP runtimes and freed elsewhere, causing a crash. In a future change the same runtime should be used in both cases to avoid crashes.
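+
+As a loose illustration of the allocate/free pairing (not code from this repository), the cgo sketch below shows the general rule that a pointer must be released by the same runtime that produced it:
+
+```go
+package main
+
+/*
+#include <stdlib.h>
+*/
+import "C"
+
+func main() {
+	// Allocated by the C runtime linked into this binary...
+	p := C.malloc(C.size_t(16))
+	// ...so it must be released by that same runtime's free, never by an
+	// allocator compiled into a different library.
+	C.free(p)
+}
+```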
+
+## Building
+
+```
+go build .
+```
+
+### AVX
+
+```shell
+go build -tags avx .
+```
+
+### AVX2
+
+```shell
+# go doesn't recognize `-mfma` as a valid compiler flag
+# see https://github.com/golang/go/issues/17895
+go env -w "CGO_CFLAGS_ALLOW=-mfma|-mf16c"
+go env -w "CGO_CXXFLAGS_ALLOW=-mfma|-mf16c"
+go build -tags=avx,avx2 .
+```
+
+## Linux
+
+### CUDA
+
+Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive):
+
+```shell
+make ggml_cuda.so
+go build -tags avx,cuda .
+```
+
+### ROCm
+
+Install [ROCm](https://rocm.docs.amd.com/en/latest/).
+
+```shell
+make ggml_hipblas.so
+go build -tags avx,rocm .
+```
+
+## Windows
+
+Download [w64devkit](https://github.com/skeeto/w64devkit/releases/latest) for a simple MinGW development environment.
+
+### CUDA
+
+Install the [CUDA toolkit v11.3.1](https://developer.nvidia.com/cuda-11-3-1-download-archive), then build the CUDA code:
+
+```shell
+make ggml_cuda.dll
+go build -tags avx,cuda .
+```
+
+### ROCm
+
+Install [ROCm](https://rocm.docs.amd.com/en/latest/).
+
+```shell
+make ggml_hipblas.dll
+go build -tags avx,rocm .
+```
+
+## Building runners
+
+```shell
+# build all runners for this platform
+make -j
+```
+
+## Vendoring
+
+Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml). While we generally strive to contribute changes back upstream to avoid drift, we carry a small set of patches which are applied to the tracking commit. A set of make targets is available to aid developers in updating to a newer tracking commit, or in working on changes.
+
+To work on the vendored code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
+
+```
+make apply-patches
+```
+
+### Updating Base Commit
+
+**Pin to new base commit**
+
+To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring`.
+
+#### Applying patches
+
+When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
+
+Start by applying the patches. If any of the patches have conflicts, `git am` will stop at the first failure.
+
+```
+make apply-patches
+```
+
+If you see an error message about a conflict, go into the `./vendor/` directory and resolve the conflict for the patch commit that failed using your preferred merge tool. Save the file(s) and continue the patch series with `git am --continue`. If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run the `create-patches` and `sync` targets to ensure everything is updated.
+
+```
+make create-patches sync
+```
+
+Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.
+
+### Generating Patches
+
+When working on new fixes or features that impact the vendored code, use the following workflow. First, get a clean tracking repo with all current patches applied:
+
+```
+make apply-patches
+```
+
+Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build; a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, then run the following to refresh the vendored code with your changes, build the backend(s), and build ollama:
+
+```
+make sync
+make -j 8
+go build .
+```
+
+> [!IMPORTANT]
+> Do **NOT** run `apply-patches` while you're iterating, as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally run this target, use `git reflog` to recover your commit(s).
+
+Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with
+
+```
+make create-patches
+```
+
+> [!IMPORTANT]
+> Once you have completed this step, it is safe to run `apply-patches` since your change is preserved in the patches.
+
+In your `./vendor/` directory, create a branch, cherry-pick the new commit onto that branch, and submit a PR upstream to llama.cpp.
+
+Commit the changes in the ollama repo and submit a PR to Ollama, which will include the vendored code update with your change, along with the patches.
+
+After your upstream PR is merged, follow the **Updating Base Commit** instructions above; however, first remove your patch before running `apply-patches`, since the new base commit already contains your change.
diff --git a/llama/amx.cpp b/llama/amx.cpp
new file mode 100644
index 000000000..a2c7e8e5a
--- /dev/null
+++ b/llama/amx.cpp
@@ -0,0 +1,246 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "amx.h"
+#include "common.h"
+#include "mmq.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-cpu-traits.h"
+
+#if defined(__gnu_linux__)
+#include <sys/syscall.h>
+#include <unistd.h>
+#endif
+
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+// AMX type_trais
+namespace ggml::cpu::amx {
+class tensor_traits : public ggml::cpu::tensor_traits {
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        size = ggml_backend_amx_desired_wsize(op);
+        return true;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT) {
+            ggml_backend_amx_mul_mat(params, op);
+            return true;
+        }
+        return false;
+    }
+};
+
+static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
+}
+}  // namespace ggml::cpu::amx
+
+// AMX buffer interface
+static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+}
+
+static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return (void *) (buffer->context);
+}
+
+static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                  uint8_t value, size_t offset, size_t size) {
+    memset((char *) tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                               const void * data, size_t offset, size_t size) {
+    if (qtype_has_amx_kernels(tensor->type)) {
+        GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type));
+        ggml_backend_amx_convert_weight(tensor, data, offset, size);
+    } else {
+        memcpy((char *) tensor->data + offset, data, size);
+    }
+
+    GGML_UNUSED(buffer);
+}
+
+/*
+// need to figure what we need to do with buffer->extra.
+static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        if (qtype_has_amx_kernels(src->type)) {
+            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst));
+        } else {
+            memcpy(dst->data, src->data, ggml_nbytes(src));
+        }
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+*/
+
+static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    memset(buffer->context, value, buffer->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_amx_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_amx_buffer_init_tensor,
+    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
+    /* .get_tensor      = */ nullptr,
+    /* .cpy_tensor      = */ nullptr,
+    /* .clear           = */ ggml_backend_amx_buffer_clear,
+    /* .reset           = */ nullptr,
+};
+
+static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "AMX";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * data = ggml_aligned_malloc(size);
+    if (data == NULL) {
+        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+        return NULL;
+    }
+
+    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
+}
+
+static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::amx {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        // handle only 2d gemm for now
+        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+        };
+
+        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
+            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
+            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
+            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features is 32x
+            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
+            // src1 must be host buffer
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            // src1 must be float32
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
+            op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {
+            return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+        }
+
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::amx
+
+static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    return ggml_backend_amx_get_alloc_size(tensor);
+
+    GGML_UNUSED(buft);
+}
+
+#define ARCH_GET_XCOMP_PERM     0x1022
+#define ARCH_REQ_XCOMP_PERM     0x1023
+#define XFEATURE_XTILECFG       17
+#define XFEATURE_XTILEDATA      18
+
+static bool ggml_amx_init() {
+#if defined(__gnu_linux__)
+    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+        fprintf(stderr, "AMX is not ready to be used!\n");
+        return false;
+    }
+    return true;
+#elif defined(_WIN32)
+    return true;
+#endif
+}
+
+ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
+        /* .iface = */ {
+                        /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
+                        /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
+                        /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
+                        /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
+                        /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
+                        /* .is_host          = */ nullptr,
+                        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
+    };
+
+    if (!ggml_amx_init()) {
+        return nullptr;
+    }
+
+    return &ggml_backend_buffer_type_amx;
+}
+
+#endif  // defined(__AMX_INT8__) && defined(__AVX512VNNI__)
diff --git a/llama/amx.h b/llama/amx.h
new file mode 100644
index 000000000..5b64b8bd8
--- /dev/null
+++ b/llama/amx.h
@@ -0,0 +1,34 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend.h"
+#include "ggml-cpu-impl.h"
+
+// GGML internal header
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
+#endif
diff --git a/llama/base64.hpp b/llama/base64.hpp
new file mode 100644
index 000000000..563247a6e
--- /dev/null
+++ b/llama/base64.hpp
@@ -0,0 +1,392 @@
+/*
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or
+distribute this software, either in source code form or as a compiled
+binary, for any purpose, commercial or non-commercial, and by any
+means.
+
+In jurisdictions that recognize copyright laws, the author or authors
+of this software dedicate any and all copyright interest in the
+software to the public domain. We make this dedication for the benefit
+of the public at large and to the detriment of our heirs and
+successors. We intend this dedication to be an overt act of
+relinquishment in perpetuity of all present and future rights to this
+software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to <https://unlicense.org>
+*/
+
+#ifndef PUBLIC_DOMAIN_BASE64_HPP_
+#define PUBLIC_DOMAIN_BASE64_HPP_
+
+#include <cstdint>
+#include <iterator>
+#include <stdexcept>
+#include <string>
+
+class base64_error : public std::runtime_error
+{
+public:
+    using std::runtime_error::runtime_error;
+};
+
+class base64
+{
+public:
+    enum class alphabet
+    {
+        /** the alphabet is detected automatically */
+        auto_,
+        /** the standard base64 alphabet is used */
+        standard,
+        /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/
+        url_filename_safe
+    };
+
+    enum class decoding_behavior
+    {
+        /** if the input is not padded, the remaining bits are ignored */
+        moderate,
+        /** if a padding character is encounter decoding is finished */
+        loose
+    };
+
+    /**
+     Encodes all the elements from `in_begin` to `in_end` to `out`.
+
+     @warning The source and destination cannot overlap. The destination must be able to hold at least
+     `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
+
+     @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
+     8 bits
+     @tparam Output_iterator the destination; the elements written to it are from the type `char`
+     @param in_begin the beginning of the source
+     @param in_end the ending of the source
+     @param out the destination iterator
+     @param alphabet which alphabet should be used
+     @returns the iterator to the next element past the last element copied
+     @throws see `Input_iterator` and `Output_iterator`
+    */
+    template<typename Input_iterator, typename Output_iterator>
+    static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
+                                  alphabet alphabet = alphabet::standard)
+    {
+        constexpr auto pad = '=';
+        const char* alpha  = alphabet == alphabet::url_filename_safe
+                                ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+                                : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+        while (in_begin != in_end) {
+            std::uint8_t i0 = 0, i1 = 0, i2 = 0;
+
+            // first character
+            i0 = static_cast<std::uint8_t>(*in_begin);
+            ++in_begin;
+
+            *out = alpha[i0 >> 2 & 0x3f];
+            ++out;
+
+            // part of first character and second
+            if (in_begin != in_end) {
+                i1 = static_cast<std::uint8_t>(*in_begin);
+                ++in_begin;
+
+                *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
+                ++out;
+            } else {
+                *out = alpha[(i0 & 0x3) << 4];
+                ++out;
+
+                // last padding
+                *out = pad;
+                ++out;
+
+                // last padding
+                *out = pad;
+                ++out;
+
+                break;
+            }
+
+            // part of second character and third
+            if (in_begin != in_end) {
+                i2 = static_cast<std::uint8_t>(*in_begin);
+                ++in_begin;
+
+                *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
+                ++out;
+            } else {
+                *out = alpha[(i1 & 0xf) << 2];
+                ++out;
+
+                // last padding
+                *out = pad;
+                ++out;
+
+                break;
+            }
+
+            // rest of third
+            *out = alpha[i2 & 0x3f];
+            ++out;
+        }
+
+        return out;
+    }
+    /**
+     Encodes a string.
+
+     @param str the string that should be encoded
+     @param alphabet which alphabet should be used
+     @returns the encoded base64 string
+     @throws see base64::encode()
+    */
+    static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
+    {
+        std::string result;
+
+        result.reserve(required_encode_size(str.length()) + 1);
+
+        encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
+
+        return result;
+    }
+    /**
+     Encodes a char array.
+
+     @param buffer the char array
+     @param size the size of the array
+     @param alphabet which alphabet should be used
+     @returns the encoded string
+    */
+    static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
+    {
+        std::string result;
+
+        result.reserve(required_encode_size(size) + 1);
+
+        encode(buffer, buffer + size, std::back_inserter(result), alphabet);
+
+        return result;
+    }
+    /**
+     Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
+     in other words: inplace decoding is possible.
+
+     @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
+     otherwise the behavior depends on the output iterator.
+
+     @tparam Input_iterator the source; the returned elements are cast to `char`
+     @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
+     @param in_begin the beginning of the source
+     @param in_end the ending of the source
+     @param out the destination iterator
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the iterator to the next element past the last element copied
+     @throws base64_error depending on the set behavior
+     @throws see `Input_iterator` and `Output_iterator`
+    */
+    template<typename Input_iterator, typename Output_iterator>
+    static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
+                                  alphabet alphabet          = alphabet::auto_,
+                                  decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        //constexpr auto pad = '=';
+        std::uint8_t last  = 0;
+        auto bits          = 0;
+
+        while (in_begin != in_end) {
+            auto c = *in_begin;
+            ++in_begin;
+
+            if (c == '=') {
+                break;
+            }
+
+            auto part = _base64_value(alphabet, c);
+
+            // enough bits for one byte
+            if (bits + 6 >= 8) {
+                *out = (last << (8 - bits)) | (part >> (bits - 2));
+                ++out;
+
+                bits -= 2;
+            } else {
+                bits += 6;
+            }
+
+            last = part;
+        }
+
+        // check padding
+        if (behavior != decoding_behavior::loose) {
+            while (in_begin != in_end) {
+                auto c = *in_begin;
+                ++in_begin;
+
+                if (c != '=') {
+                    throw base64_error("invalid base64 character.");
+                }
+            }
+        }
+
+        return out;
+    }
+    /**
+     Decodes a string.
+
+     @param str the base64 encoded string
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the decoded string
+     @throws see base64::decode()
+    */
+    static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
+                              decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        std::string result;
+
+        result.reserve(max_decode_size(str.length()));
+
+        decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
+
+        return result;
+    }
+    /**
+     Decodes a string.
+
+     @param buffer the base64 encoded buffer
+     @param size the size of the buffer
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the decoded string
+     @throws see base64::decode()
+    */
+    static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
+                              decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        std::string result;
+
+        result.reserve(max_decode_size(size));
+
+        decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
+
+        return result;
+    }
+    /**
+     Decodes a string inplace.
+
+     @param[in,out] str the base64 encoded string
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @throws base64::decode_inplace()
+    */
+    static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
+                               decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
+    }
+    /**
+     Decodes a char array inplace.
+
+     @param[in,out] str the string array
+     @param size the length of the array
+     @param alphabet which alphabet should be used
+     @param behavior the behavior when an error was detected
+     @returns the pointer to the next element past the last element decoded
+     @throws base64::decode_inplace()
+    */
+    static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
+                                decoding_behavior behavior = decoding_behavior::moderate)
+    {
+        return decode(str, str + size, str, alphabet, behavior);
+    }
+    /**
+     Returns the required decoding size for a given size. The value is calculated with the following formula:
+
+     $$
+     \lceil \frac{size}{4} \rceil \cdot 3
+     $$
+
+     @param size the size of the encoded input
+     @returns the size of the resulting decoded buffer; this the absolute maximum
+    */
+    static std::size_t max_decode_size(std::size_t size) noexcept
+    {
+        return (size / 4 + (size % 4 ? 1 : 0)) * 3;
+    }
+    /**
+     Returns the required encoding size for a given size. The value is calculated with the following formula:
+
+     $$
+     \lceil \frac{size}{3} \rceil \cdot 4
+     $$
+
+     @param size the size of the decoded input
+     @returns the size of the resulting encoded buffer
+    */
+    static std::size_t required_encode_size(std::size_t size) noexcept
+    {
+        return (size / 3 + (size % 3 ? 1 : 0)) * 4;
+    }
+
+private:
+    static std::uint8_t _base64_value(alphabet& alphabet, char c)
+    {
+        if (c >= 'A' && c <= 'Z') {
+            return c - 'A';
+        } else if (c >= 'a' && c <= 'z') {
+            return c - 'a' + 26;
+        } else if (c >= '0' && c <= '9') {
+            return c - '0' + 52;
+        }
+
+        // comes down to alphabet
+        if (alphabet == alphabet::standard) {
+            if (c == '+') {
+                return 62;
+            } else if (c == '/') {
+                return 63;
+            }
+        } else if (alphabet == alphabet::url_filename_safe) {
+            if (c == '-') {
+                return 62;
+            } else if (c == '_') {
+                return 63;
+            }
+        } // auto detect
+        else {
+            if (c == '+') {
+                alphabet = alphabet::standard;
+
+                return 62;
+            } else if (c == '/') {
+                alphabet = alphabet::standard;
+
+                return 63;
+            } else if (c == '-') {
+                alphabet = alphabet::url_filename_safe;
+
+                return 62;
+            } else if (c == '_') {
+                alphabet = alphabet::url_filename_safe;
+
+                return 63;
+            }
+        }
+
+        throw base64_error("invalid base64 character.");
+    }
+};
+
+#endif // !PUBLIC_DOMAIN_BASE64_HPP_
diff --git a/llama/build-info.cpp b/llama/build-info.cpp
new file mode 100644
index 000000000..b2c1dba73
--- /dev/null
+++ b/llama/build-info.cpp
@@ -0,0 +1,4 @@
+int LLAMA_BUILD_NUMBER = 0;
+char const *LLAMA_COMMIT = "ba1cb19cdd0d92e012e0f6e009e0620f854b6afd";
+char const *LLAMA_COMPILER = "";
+char const *LLAMA_BUILD_TARGET = "";
diff --git a/llama/clip.cpp b/llama/clip.cpp
new file mode 100644
index 000000000..d8cb50938
--- /dev/null
+++ b/llama/clip.cpp
@@ -0,0 +1,2883 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// NOTE: This is modified from clip.cpp only for LLaVA,
+// so there might be still unnecessary artifacts hanging around
+// I'll gradually clean and extend it
+// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
+#include "clip.h"
+#include "ggml.h"
+#include "ggml-cpu.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#define STB_IMAGE_IMPLEMENTATION
+#include "stb_image.h"
+
+#include <cassert>
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <stdexcept>
+#include <vector>
+#include <sstream>
+#include <cinttypes>
+#include <limits>
+
+#if defined(LLAVA_LOG_OFF)
+#   define LOG_INF(...)
+#   define LOG_WRN(...)
+#   define LOG_ERR(...)
+#   define LOG_DBG(...)
+#else // defined(LLAVA_LOG_OFF)
+#   define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
+#   define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
+#   define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0)
+#   define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
+#endif // defined(LLAVA_LOG_OFF)
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#if __GLIBCXX__
+#include <cstdio>
+#include <ext/stdio_filebuf.h>
+#include <fcntl.h>
+#endif
+#endif
+
+//#define CLIP_DEBUG_FUNCTIONS
+
+// RGB uint8 image
+struct clip_image_u8 {
+    int nx;
+    int ny;
+
+    std::vector<uint8_t> buf;
+};
+
+// RGB float32 image (NHWC)
+// Memory layout: RGBRGBRGB...
+struct clip_image_f32 {
+    int nx;
+    int ny;
+
+    std::vector<float> buf;
+};
+
+static std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), buf.size());
+}
+
+//
+// key constants
+//
+
+#define KEY_FTYPE               "general.file_type"
+#define KEY_NAME                "general.name"
+#define KEY_DESCRIPTION         "general.description"
+#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
+#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
+#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
+#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
+#define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
+#define KEY_HAS_QWEN2VL_MERGER  "clip.has_qwen2vl_merger"
+#define KEY_USE_GELU            "clip.use_gelu"
+#define KEY_USE_SILU            "clip.use_silu"
+#define KEY_N_EMBD              "clip.%s.embedding_length"
+#define KEY_N_FF                "clip.%s.feed_forward_length"
+#define KEY_N_BLOCK             "clip.%s.block_count"
+#define KEY_N_HEAD              "clip.%s.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
+#define KEY_PROJ_DIM            "clip.%s.projection_dim"
+#define KEY_TOKENS              "tokenizer.ggml.tokens"
+#define KEY_N_POSITIONS         "clip.text.context_length"
+#define KEY_IMAGE_SIZE          "clip.vision.image_size"
+#define KEY_PATCH_SIZE          "clip.vision.patch_size"
+#define KEY_IMAGE_MEAN          "clip.vision.image_mean"
+#define KEY_IMAGE_STD           "clip.vision.image_std"
+#define KEY_PROJ_TYPE           "clip.projector_type"
+
+#define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
+#define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
+#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+
+
+//
+// tensor name constants
+//
+
+#define TN_TOKEN_EMBD      "%s.token_embd.weight"
+#define TN_POS_EMBD        "%s.position_embd.weight"
+#define TN_CLASS_EMBD      "v.class_embd"
+#define TN_PATCH_EMBD      "v.patch_embd.weight"  // not rename tensor with ".0" postfix for backwrad compat
+#define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
+#define TN_PATCH_BIAS      "v.patch_embd.bias"
+#define TN_ATTN_K          "%s.blk.%d.attn_k.%s"
+#define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
+#define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
+#define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
+#define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
+#define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
+#define TN_LN_1            "%s.blk.%d.ln1.%s"
+#define TN_LN_2            "%s.blk.%d.ln2.%s"
+#define TN_LN_PRE          "%s.pre_ln.%s"
+#define TN_LN_POST         "%s.post_ln.%s"
+#define TN_TEXT_PROJ       "text_projection.weight"
+#define TN_VIS_PROJ        "visual_projection.weight"
+#define TN_LLAVA_PROJ      "mm.%d.%s"
+#define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
+#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
+#define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
+#define TN_IMAGE_NEWLINE   "model.image_newline"
+
+#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
+#define TN_MINICPMV_QUERY "resampler.query"
+#define TN_MINICPMV_PROJ "resampler.proj.weight"
+#define TN_MINICPMV_KV_PROJ "resampler.kv.weight"
+#define TN_MINICPMV_ATTN "resampler.attn.%s.%s"
+#define TN_MINICPMV_LN "resampler.ln_%s.%s"
+
+
+enum projector_type {
+    PROJECTOR_TYPE_MLP,
+    PROJECTOR_TYPE_MLP_NORM,
+    PROJECTOR_TYPE_LDP,
+    PROJECTOR_TYPE_LDPV2,
+    PROJECTOR_TYPE_RESAMPLER,
+    PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_UNKNOWN,
+};
+
+static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
+    { PROJECTOR_TYPE_MLP, "mlp" },
+    { PROJECTOR_TYPE_LDP, "ldp" },
+    { PROJECTOR_TYPE_LDPV2, "ldpv2"},
+    { PROJECTOR_TYPE_RESAMPLER, "resampler"},
+    { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
+};
+
+
+//
+// utilities to get data from a gguf file
+//
+
+static int get_key_idx(const gguf_context * ctx, const char * key) {
+    int i = gguf_find_key(ctx, key);
+    if (i == -1) {
+        LOG_ERR("key %s not found in file\n", key);
+        throw std::runtime_error(format("Missing required key: %s", key));
+    }
+
+    return i;
+}
+
+static uint32_t get_u32(const gguf_context * ctx, const std::string & key) {
+    const int i = get_key_idx(ctx, key.c_str());
+
+    return gguf_get_val_u32(ctx, i);
+}
+
+static float get_f32(const gguf_context * ctx, const std::string & key) {
+    const int i = get_key_idx(ctx, key.c_str());
+
+    return gguf_get_val_f32(ctx, i);
+}
+
+static struct ggml_tensor * get_tensor(struct ggml_context * ctx, const std::string & name) {
+    struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str());
+    if (!cur) {
+        throw std::runtime_error(format("%s: unable to find tensor %s\n", __func__, name.c_str()));
+    }
+
+    return cur;
+}
+
+static std::string get_ftype(int ftype) {
+    return ggml_type_name(static_cast<ggml_type>(ftype));
+}
+
+static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
+    switch (type) {
+        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
+        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
+        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
+        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
+        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
+        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
+        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
+        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
+        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
+        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
+        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
+        default:                return format("unknown type %d", type);
+    }
+}
+
+static void replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
+
+static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
+    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
+
+    switch (type) {
+        case GGUF_TYPE_STRING:
+            return gguf_get_val_str(ctx_gguf, i);
+        case GGUF_TYPE_ARRAY:
+            {
+                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
+                int arr_n = gguf_get_arr_n(ctx_gguf, i);
+                const void * data = gguf_get_arr_data(ctx_gguf, i);
+                std::stringstream ss;
+                ss << "[";
+                for (int j = 0; j < arr_n; j++) {
+                    if (arr_type == GGUF_TYPE_STRING) {
+                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
+                        // escape quotes
+                        replace_all(val, "\\", "\\\\");
+                        replace_all(val, "\"", "\\\"");
+                        ss << '"' << val << '"';
+                    } else if (arr_type == GGUF_TYPE_ARRAY) {
+                        ss << "???";
+                    } else {
+                        ss << gguf_data_to_str(arr_type, data, j);
+                    }
+                    if (j < arr_n - 1) {
+                        ss << ", ";
+                    }
+                }
+                ss << "]";
+                return ss.str();
+            }
+        default:
+            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
+    }
+}
+
+static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") {
+    size_t tensor_size = ggml_nbytes(tensor);
+    LOG_INF("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n",
+            prefix, ggml_n_dims(tensor), tensor->name, tensor_size,
+            tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], ggml_type_name(tensor->type));
+}
+
+static projector_type clip_projector_type_from_string(const std::string & name) {
+    for (const auto & kv : PROJECTOR_TYPE_NAMES) { // NOLINT
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+    return PROJECTOR_TYPE_UNKNOWN;
+}
+
+#ifdef CLIP_DEBUG_FUNCTIONS
+static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    // PPM header: P6 format, width, height, and max color value
+    file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
+
+    // Write pixel data
+    for (size_t i = 0; i < img.buf.size(); i += 3) {
+        // PPM expects binary data in RGB format, which matches our image buffer
+        file.write(reinterpret_cast<const char *>(&img.buf[i]), 3);
+    }
+
+    file.close();
+}
+
+static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
+    int bytesPerPixel = 3;
+    int widthInBytes = img.nx * bytesPerPixel;
+    int paddingAmount = (4 - (widthInBytes % 4)) % 4;
+    int stride = widthInBytes + paddingAmount;
+
+    // Bitmap file header
+    unsigned char fileHeader[14] = {
+        'B','M',     // Signature
+        0,0,0,0,    // Image file size in bytes
+        0,0,0,0,    // Reserved
+        54,0,0,0    // Start of pixel array
+    };
+
+    // Total file size
+    fileSize = 54 + (stride * img.ny);
+    fileHeader[2] = (unsigned char)(fileSize);
+    fileHeader[3] = (unsigned char)(fileSize >> 8);
+    fileHeader[4] = (unsigned char)(fileSize >> 16);
+    fileHeader[5] = (unsigned char)(fileSize >> 24);
+
+    // Bitmap information header (BITMAPINFOHEADER)
+    unsigned char infoHeader[40] = {
+        40,0,0,0,   // Size of this header (40 bytes)
+        0,0,0,0,    // Image width
+        0,0,0,0,    // Image height
+        1,0,        // Number of color planes
+        24,0,       // Bits per pixel
+        0,0,0,0,    // No compression
+        0,0,0,0,    // Image size (can be 0 for no compression)
+        0,0,0,0,    // X pixels per meter (not specified)
+        0,0,0,0,    // Y pixels per meter (not specified)
+        0,0,0,0,    // Total colors (color table not used)
+        0,0,0,0     // Important colors (all are important)
+    };
+
+    // Width and height in the information header
+    infoHeader[4] = (unsigned char)(img.nx);
+    infoHeader[5] = (unsigned char)(img.nx >> 8);
+    infoHeader[6] = (unsigned char)(img.nx >> 16);
+    infoHeader[7] = (unsigned char)(img.nx >> 24);
+    infoHeader[8] = (unsigned char)(img.ny);
+    infoHeader[9] = (unsigned char)(img.ny >> 8);
+    infoHeader[10] = (unsigned char)(img.ny >> 16);
+    infoHeader[11] = (unsigned char)(img.ny >> 24);
+
+    // Write file headers
+    file.write(reinterpret_cast<char *>(fileHeader), sizeof(fileHeader));
+    file.write(reinterpret_cast<char *>(infoHeader), sizeof(infoHeader));
+
+    // Pixel data
+    std::vector<unsigned char> padding(3, 0); // Max padding size to be added to each row
+    for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
+        for (int x = 0; x < img.nx; ++x) {
+            // Each pixel
+            size_t pixelIndex = (y * img.nx + x) * 3;
+            unsigned char pixel[3] = {
+                img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
+                img.buf[pixelIndex + 1],
+                img.buf[pixelIndex]
+            };
+            file.write(reinterpret_cast<char *>(pixel), 3);
+        }
+        // Write padding for the row
+        file.write(reinterpret_cast<char *>(padding.data()), paddingAmount);
+    }
+
+    file.close();
+}
+
+// debug function to convert f32 to u8
+static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(3 * src.nx * src.ny);
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        dst.buf[i] = static_cast<uint8_t>(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
+    }
+}
+#endif
+
+
+//
+// clip layers
+//
+
+struct clip_hparams {
+    int32_t image_size;
+    int32_t patch_size;
+    int32_t hidden_size;
+    int32_t n_intermediate;
+    int32_t projection_dim;
+    int32_t n_head;
+    int32_t n_layer;
+
+    float eps;
+
+    char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default)
+
+    int32_t image_grid_pinpoints[32];
+    int32_t image_crop_resolution;
+};
+
+struct clip_layer {
+    // attention
+    struct ggml_tensor * k_w;
+    struct ggml_tensor * k_b;
+    struct ggml_tensor * q_w;
+    struct ggml_tensor * q_b;
+    struct ggml_tensor * v_w;
+    struct ggml_tensor * v_b;
+
+    struct ggml_tensor * o_w;
+    struct ggml_tensor * o_b;
+
+    // layernorm 1
+    struct ggml_tensor * ln_1_w;
+    struct ggml_tensor * ln_1_b;
+
+    // ff
+    struct ggml_tensor * ff_i_w;
+    struct ggml_tensor * ff_i_b;
+
+    struct ggml_tensor * ff_o_w;
+    struct ggml_tensor * ff_o_b;
+
+    // layernorm 2
+    struct ggml_tensor * ln_2_w;
+    struct ggml_tensor * ln_2_b;
+};
+
+struct clip_vision_model {
+    struct clip_hparams hparams;
+
+    // embeddings
+    struct ggml_tensor * class_embedding;
+    struct ggml_tensor * patch_embeddings_0;
+    struct ggml_tensor * patch_embeddings_1;  // second Conv2D kernel when we decouple Conv3D along the temporal dimension (Qwen2VL)
+    struct ggml_tensor * patch_bias;
+    struct ggml_tensor * position_embeddings;
+
+    struct ggml_tensor * pre_ln_w;
+    struct ggml_tensor * pre_ln_b;
+
+    std::vector<clip_layer> layers;
+
+    struct ggml_tensor * post_ln_w;
+    struct ggml_tensor * post_ln_b;
+
+    struct ggml_tensor * projection;
+
+    // LLaVA projection
+    struct ggml_tensor * mm_0_w = NULL;
+    struct ggml_tensor * mm_0_b = NULL;
+    struct ggml_tensor * mm_2_w = NULL;
+    struct ggml_tensor * mm_2_b = NULL;
+
+    struct ggml_tensor * image_newline = NULL;
+
+    // Yi type models with mlp+normalization projection
+    struct ggml_tensor * mm_1_w = NULL; // Yi type models have 0, 1, 3, 4
+    struct ggml_tensor * mm_1_b = NULL;
+    struct ggml_tensor * mm_3_w = NULL;
+    struct ggml_tensor * mm_3_b = NULL;
+    struct ggml_tensor * mm_4_w = NULL;
+    struct ggml_tensor * mm_4_b = NULL;
+
+    // MobileVLM projection
+    struct ggml_tensor * mm_model_mlp_1_w;
+    struct ggml_tensor * mm_model_mlp_1_b;
+    struct ggml_tensor * mm_model_mlp_3_w;
+    struct ggml_tensor * mm_model_mlp_3_b;
+    struct ggml_tensor * mm_model_block_1_block_0_0_w;
+    struct ggml_tensor * mm_model_block_1_block_0_1_w;
+    struct ggml_tensor * mm_model_block_1_block_0_1_b;
+    struct ggml_tensor * mm_model_block_1_block_1_fc1_w;
+    struct ggml_tensor * mm_model_block_1_block_1_fc1_b;
+    struct ggml_tensor * mm_model_block_1_block_1_fc2_w;
+    struct ggml_tensor * mm_model_block_1_block_1_fc2_b;
+    struct ggml_tensor * mm_model_block_1_block_2_0_w;
+    struct ggml_tensor * mm_model_block_1_block_2_1_w;
+    struct ggml_tensor * mm_model_block_1_block_2_1_b;
+    struct ggml_tensor * mm_model_block_2_block_0_0_w;
+    struct ggml_tensor * mm_model_block_2_block_0_1_w;
+    struct ggml_tensor * mm_model_block_2_block_0_1_b;
+    struct ggml_tensor * mm_model_block_2_block_1_fc1_w;
+    struct ggml_tensor * mm_model_block_2_block_1_fc1_b;
+    struct ggml_tensor * mm_model_block_2_block_1_fc2_w;
+    struct ggml_tensor * mm_model_block_2_block_1_fc2_b;
+    struct ggml_tensor * mm_model_block_2_block_2_0_w;
+    struct ggml_tensor * mm_model_block_2_block_2_1_w;
+    struct ggml_tensor * mm_model_block_2_block_2_1_b;
+
+    // MobileVLM_V2 projection
+    struct ggml_tensor * mm_model_mlp_0_w;
+    struct ggml_tensor * mm_model_mlp_0_b;
+    struct ggml_tensor * mm_model_mlp_2_w;
+    struct ggml_tensor * mm_model_mlp_2_b;
+    struct ggml_tensor * mm_model_peg_0_w;
+    struct ggml_tensor * mm_model_peg_0_b;
+
+    // MINICPMV projection
+    struct ggml_tensor * mm_model_pos_embed_k;
+    struct ggml_tensor * mm_model_query;
+    struct ggml_tensor * mm_model_proj;
+    struct ggml_tensor * mm_model_kv_proj;
+    struct ggml_tensor * mm_model_attn_q_w;
+    struct ggml_tensor * mm_model_attn_q_b;
+    struct ggml_tensor * mm_model_attn_k_w;
+    struct ggml_tensor * mm_model_attn_k_b;
+    struct ggml_tensor * mm_model_attn_v_w;
+    struct ggml_tensor * mm_model_attn_v_b;
+    struct ggml_tensor * mm_model_attn_o_w;
+    struct ggml_tensor * mm_model_attn_o_b;
+    struct ggml_tensor * mm_model_ln_q_w;
+    struct ggml_tensor * mm_model_ln_q_b;
+    struct ggml_tensor * mm_model_ln_kv_w;
+    struct ggml_tensor * mm_model_ln_kv_b;
+    struct ggml_tensor * mm_model_ln_post_w;
+    struct ggml_tensor * mm_model_ln_post_b;
+};
+
+struct clip_ctx {
+    bool has_text_encoder    = false;
+    bool has_vision_encoder  = false;
+    bool has_llava_projector = false;
+    bool has_minicpmv_projector = false;
+    bool has_qwen2vl_merger = false;
+    int minicpmv_version = 2;
+
+    struct clip_vision_model vision_model;
+    projector_type proj_type = PROJECTOR_TYPE_MLP;
+
+    float image_mean[3];
+    float image_std[3];
+    bool use_gelu = false;
+    bool use_silu = false;
+    int32_t ftype = 1;
+
+    bool has_class_embedding = true;
+    bool has_pre_norm = true;
+    bool has_post_norm = false;
+    bool has_patch_bias = false;
+
+    struct gguf_context * ctx_gguf;
+    struct ggml_context * ctx_data;
+
+    std::vector<uint8_t> buf_compute_meta;
+
+    // memory buffers to evaluate the model
+    ggml_backend_buffer_t params_buffer  = NULL;
+
+    ggml_backend_t backend       = NULL;
+    ggml_gallocr_t compute_alloc = NULL;
+
+    struct clip_image_size * load_image_size;
+};
+
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+    if (!ctx->has_vision_encoder) {
+        LOG_ERR("This gguf file seems to have no vision encoder\n");
+        return nullptr;
+    }
+
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+    if (ctx->has_minicpmv_projector) {
+        if (load_image_size == nullptr) {
+            load_image_size = clip_image_size_init();
+        }
+        LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height);
+        image_size_width  = load_image_size->width;
+        image_size_height = load_image_size->height;
+        if (is_inf) {
+            image_size_width  = imgs->data->nx;
+            image_size_height = imgs->data->ny;
+        }
+    }
+    else if (ctx->has_qwen2vl_merger) {
+        // use the image's native resolution when the image is available
+        if (is_inf) {
+        // if (imgs->data->nx && imgs->data->ny) {
+            image_size_width  = imgs->data->nx;
+            image_size_height = imgs->data->ny;
+        }
+    }
+    const int patch_size           = hparams.patch_size;
+    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int patches_w            = image_size_width / patch_size;
+    const int patches_h            = image_size_height / patch_size;
+    const int num_positions        = num_patches + (ctx->has_class_embedding ? 1 : 0);
+    const int num_position_ids     = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
+    const int hidden_size          = hparams.hidden_size;
+    const int n_head               = hparams.n_head;
+    const int d_head               = hidden_size / n_head;
+    int n_layer                    = hparams.n_layer;
+    const float eps                = hparams.eps;
+    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
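+    // d_head is split into four equal sections for the multi-section rotary embedding
+    // (ggml_rope_multi) applied on the Qwen2VL merger path below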
+
+    const int batch_size = imgs->size;
+
+    if (ctx->has_llava_projector || ctx->has_minicpmv_projector) {
+        GGML_ASSERT(batch_size == 1);
+    }
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
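+    // ViT-style patch embedding: a strided Conv2D splits the image into
+    // patch_size x patch_size patches and projects each to hidden_size channels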
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    if (ctx->has_qwen2vl_merger) {
+        GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
+        GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
+
+        auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_add(ctx0, inp, inp_1);
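+        // reorder the patch sequence so that every 2x2 block of neighbouring patches
+        // becomes four consecutive entries; the MERGER projector later concatenates
+        // each group of four before projecting to the LLM embedding size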
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            hidden_size * 2, patches_w / 2, patches_h, batch_size);
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+        inp = ggml_reshape_3d(
+            ctx0, inp,
+            hidden_size, patches_w * patches_h, batch_size);
+    }
+    else {
+        inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
+    }
+
+    if (ctx->has_patch_bias) {
+        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
+        inp = ggml_add(ctx0, inp, model.patch_bias);
+    }
+    struct ggml_tensor * embeddings = inp;
+    struct ggml_tensor * pos_embed = nullptr;
+
+    if (ctx->has_llava_projector) {
+        // concat class_embeddings and patch_embeddings
+        if (ctx->has_class_embedding) {
+            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+            ggml_set_name(embeddings, "embeddings");
+            ggml_set_input(embeddings);
+            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+            embeddings = ggml_acc(ctx0, embeddings, inp,
+                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
+        }
+    }
+
+    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
+        embeddings =
+            ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+    }
+
+    if (ctx->has_minicpmv_projector) {
+        int pos_w = image_size_width/patch_size;
+        int pos_h = image_size_height/patch_size;
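+        // the positional embedding width below (4096 for minicpmv_version 2, 3584 for
+        // version 3) presumably matches the hidden size of the paired language model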
+        if (ctx->minicpmv_version == 2) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
+        }
+        else if (ctx->minicpmv_version == 3) {
+            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
+        }
+        ggml_set_name(pos_embed, "pos_embed");
+        ggml_set_input(pos_embed);
+    }
+
+    // pre-layernorm
+    if (ctx->has_pre_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "pre_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+    }
+
+    // loop over layers
+    if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) {
+        // TODO: figure out why we are doing it this way
+        n_layer += 1;
+    }
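+    // by default the last transformer block is skipped (n_layer - 1), presumably because
+    // the llava projector reads the penultimate layer's output; the +1 above restores the
+    // full depth for the minicpmv / qwen2vl paths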
+    for (int il = 0; il < n_layer - 1; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        //const size_t nb_q_w = model.layers[il].q_w->nb[0];
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
+                           model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
+            if (ctx->has_qwen2vl_merger) {
+                Q = ggml_rope_multi(
+                    ctx0, Q, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+            }
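+            // scaled dot-product attention: pre-scale Q by 1/sqrt(d_head)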
+            Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            if (ctx->has_qwen2vl_merger) {
+                K = ggml_rope_multi(
+                    ctx0, K, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+            }
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, i.e., the residual connection
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        if (ctx->use_gelu) {
+            cur = ggml_gelu_inplace(ctx0, cur);
+        } else if (ctx->use_silu) {
+            cur = ggml_silu_inplace(ctx0, cur);
+        } else {
+            cur = ggml_gelu_quick_inplace(ctx0, cur);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+
+    }
+
+    // post-layernorm
+    if (ctx->has_post_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    // llava projector
+    if (ctx->has_llava_projector) {
+        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+
+        struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+        ggml_set_name(patches, "patches");
+        ggml_set_input(patches);
+
+        // shape [1, 576, 1024]
+        // ne is whcn, ne = [1024, 576, 1, 1]
+        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+        // print_tensor_info(embeddings, "embeddings");
+
+        // llava projector
+        if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+            embeddings = ggml_gelu(ctx0, embeddings);
+            embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        }
+        else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+            // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
+            // First LayerNorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
+                                model.mm_1_b);
+
+            // GELU activation
+            embeddings = ggml_gelu(ctx0, embeddings);
+
+            // Second linear layer
+            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+            embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
+
+            // Second LayerNorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
+                                model.mm_4_b);
+        }
+        else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+            // MobileVLM projector
+            int n_patch = 24;
+            struct ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+            mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
+            mlp_1 = ggml_gelu(ctx0, mlp_1);
+            struct ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+            mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
+            // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
+
+            // block 1
+            struct ggml_tensor * block_1 = nullptr;
+            {
+                // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
+                mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
+                mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
+                // stride = 1, padding = 1, bias is nullptr
+                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+
+                // layer norm
+                // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+
+                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                // hardswish
+                struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                // pointwise conv
+                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
+                block_1 = ggml_relu(ctx0, block_1);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
+                block_1 = ggml_hardsigmoid(ctx0, block_1);
+                // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
+                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                int w = block_1->ne[0], h = block_1->ne[1];
+                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+
+                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                // residual
+                block_1 = ggml_add(ctx0, mlp_3, block_1);
+            }
+
+            // block_2
+            {
+                // stride = 2
+                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+
+                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                // layer norm
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                // hardswish
+                struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                // not sure the parameters are right for global average pooling
+                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                // pointwise conv
+                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
+                block_1 = ggml_relu(ctx0, block_1);
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
+                block_1 = ggml_hardsigmoid(ctx0, block_1);
+
+                // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                int w = block_1->ne[0], h = block_1->ne[1];
+                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                block_1 = ggml_norm(ctx0, block_1, eps);
+                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
+                block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
+                // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
+            }
+            embeddings = block_1;
+        }
+        else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
+        {
+            int n_patch = 24;
+            struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
+            mlp_0 = ggml_gelu(ctx0, mlp_0);
+            struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+            mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
+            // mlp_2 ne = [2048, 576, 1, 1]
+            // // AVG Pool Layer 2*2, strides = 2
+            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
+            // mlp_2 ne = [576, 2048, 1, 1]
+            mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
+            // mlp_2 ne [24, 24, 2048, 1]
+            mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
+            // weight ne = [3, 3, 2048, 1]
+            struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+            peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
+            peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
+            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
+            peg_0 = ggml_add(ctx0, peg_0, mlp_2);
+            peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
+            embeddings = peg_0;
+        }
+        else {
+            GGML_ABORT("fatal error");
+        }
+    }
+    // minicpmv projector
+    else if (ctx->has_minicpmv_projector)
+    {
+        if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            struct ggml_tensor * q = model.mm_model_query;
+            { // layernorm
+                q = ggml_norm(ctx0, q, eps);
+                q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+            }
+            struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+            { // layernorm
+                v = ggml_norm(ctx0, v, eps);
+                v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+            }
+            struct ggml_tensor * k;
+            { // position
+                // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+                k = ggml_add(ctx0, v, pos_embed);
+            }
+
+            { // attention
+                int hidden_size = 4096;
+                const int d_head = 128;
+                int n_head = hidden_size/d_head;
+                int num_query = 96;
+                if (ctx->minicpmv_version == 2) {
+                    hidden_size = 4096;
+                    n_head = hidden_size/d_head;
+                    num_query = 96;
+                }
+                else if (ctx->minicpmv_version == 3) {
+                    hidden_size = 3584;
+                    n_head = hidden_size/d_head;
+                    num_query = 64;
+                }
+
+                struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+                Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
+                struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+                struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+                // permute
+                Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+                Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+                Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+                K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+                K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+                K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+                V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+                V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+                V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+                KQ = ggml_soft_max_inplace(ctx0, KQ);
+                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+                KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+                KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
+
+                embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
+            }
+            { // layernorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+            }
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+        }
+        else {
+            GGML_ASSERT(false);
+        }
+    }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
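+        // concatenate every four consecutive patch embeddings (the 2x2 blocks arranged
+        // in the qwen2vl merger reshape above) and project them with a two-layer MLP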
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
+
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+        // GELU activation
+        embeddings = ggml_gelu(ctx0, embeddings);
+
+        // Second linear layer
+        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
+// read and create ggml_context containing the tensors and their data
+struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+    struct ggml_context * meta = NULL;
+
+    struct gguf_init_params params = {
+        /*.no_alloc = */ true,
+        /*.ctx      = */ &meta,
+    };
+
+    struct gguf_context * ctx = gguf_init_from_file(fname, params);
+    if (!ctx) {
+        throw std::runtime_error(format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
+    }
+
+    if (verbosity >= 1) {
+        const int n_tensors = gguf_get_n_tensors(ctx);
+        const int n_kv = gguf_get_n_kv(ctx);
+        const int ftype = get_u32(ctx, KEY_FTYPE);
+        const std::string ftype_str = get_ftype(ftype);
+        const int idx_desc = get_key_idx(ctx, KEY_DESCRIPTION);
+        const std::string description = gguf_get_val_str(ctx, idx_desc);
+        const int idx_name = gguf_find_key(ctx, KEY_NAME);
+        if (idx_name != -1) { // make the name optional for now, as some uploaded models are missing it due to a bug
+            const std::string name = gguf_get_val_str(ctx, idx_name);
+            LOG_INF("%s: model name:   %s\n", __func__, name.c_str());
+        }
+        LOG_INF("%s: description:  %s\n", __func__, description.c_str());
+        LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx));
+        LOG_INF("%s: alignment:    %zu\n", __func__, gguf_get_alignment(ctx));
+        LOG_INF("%s: n_tensors:    %d\n", __func__, n_tensors);
+        LOG_INF("%s: n_kv:         %d\n", __func__, n_kv);
+        LOG_INF("%s: ftype:        %s\n", __func__, ftype_str.c_str());
+        LOG_INF("\n");
+    }
+    const int n_tensors = gguf_get_n_tensors(ctx);
+
+    // kv
+    const int n_kv = gguf_get_n_kv(ctx);
+    LOG_INF("%s: loaded meta data with %d key-value pairs and %d tensors from %s\n",
+        __func__, n_kv, n_tensors, fname);
+    {
+        std::map<enum ggml_type, uint32_t> n_type;
+
+        for (int i = 0; i < n_tensors; i++) {
+            enum ggml_type type = gguf_get_tensor_type(ctx, i);
+
+            n_type[type]++;
+        }
+
+        LOG_INF("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
+        for (int i = 0; i < n_kv; i++) {
+            const char * name           = gguf_get_key(ctx, i);
+            const enum gguf_type type   = gguf_get_kv_type(ctx, i);
+            const std::string type_name =
+                type == GGUF_TYPE_ARRAY
+                ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(ctx, i)), gguf_get_arr_n(ctx, i))
+                : gguf_type_name(type);
+
+            std::string value          = gguf_kv_to_str(ctx, i);
+            const size_t MAX_VALUE_LEN = 40;
+            if (value.size() > MAX_VALUE_LEN) {
+                value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
+            }
+            replace_all(value, "\n", "\\n");
+
+            LOG_INF("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
+        }
+
+        // print type counts
+        for (auto & kv : n_type) {
+            if (kv.second == 0) {
+                continue;
+            }
+
+            LOG_INF("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
+        }
+    }
+
+    // data
+    size_t model_size = 0;
+    {
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            const size_t offset = gguf_get_tensor_offset(ctx, i);
+            enum ggml_type type = gguf_get_tensor_type(ctx, i);
+            struct ggml_tensor * cur = ggml_get_tensor(meta, name);
+            size_t tensor_size = ggml_nbytes(cur);
+            model_size += tensor_size;
+            if (verbosity >= 3) {
+                LOG_INF("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
+                       __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
+            }
+        }
+    }
+
+    clip_ctx * new_clip = new clip_ctx{};
+
+    // update projector type
+    {
+        int idx = gguf_find_key(ctx, KEY_PROJ_TYPE);
+        if (idx != -1) {
+            const std::string proj_type = gguf_get_val_str(ctx, idx);
+            new_clip->proj_type = clip_projector_type_from_string(proj_type);
+        } else {
+            new_clip->proj_type = PROJECTOR_TYPE_MLP;
+        }
+
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP) {
+            if (gguf_find_tensor(ctx, format(TN_LLAVA_PROJ, 3, "weight").c_str()) != -1) {
+                new_clip->proj_type = PROJECTOR_TYPE_MLP_NORM;
+            }
+        }
+    }
+
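+    // select a GPU backend if one was compiled in; when several GGML_USE_* flags are
+    // defined, the last matching block below wins, and the CPU backend is the fallback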
+#ifdef GGML_USE_CUDA
+   new_clip->backend = ggml_backend_cuda_init(0);
+   LOG_INF("%s: CLIP using CUDA backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_METAL
+   new_clip->backend = ggml_backend_metal_init();
+   LOG_INF("%s: CLIP using Metal backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_CANN
+   new_clip->backend = ggml_backend_cann_init(0);
+   LOG_INF("%s: CLIP using CANN backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_VULKAN
+   new_clip->backend = ggml_backend_vk_init(0);
+   LOG_INF("%s: CLIP using Vulkan backend\n", __func__);
+#endif
+
+#ifdef GGML_USE_SYCL
+   new_clip->backend = ggml_backend_sycl_init(0);
+   LOG_INF("%s: CLIP using SYCL backend\n", __func__);
+#endif
+
+    if (!new_clip->backend) {
+        new_clip->backend = ggml_backend_cpu_init();
+        LOG_INF("%s: CLIP using CPU backend\n", __func__);
+    }
+
+    // model size and capabilities
+    {
+        int idx = get_key_idx(ctx, KEY_HAS_TEXT_ENC);
+        new_clip->has_text_encoder = gguf_get_val_bool(ctx, idx);
+
+        idx = get_key_idx(ctx, KEY_HAS_VIS_ENC);
+        new_clip->has_vision_encoder = gguf_get_val_bool(ctx, idx);
+
+        idx = gguf_find_key(ctx, KEY_HAS_LLAVA_PROJ);
+        if (idx != -1) {
+            new_clip->has_llava_projector = gguf_get_val_bool(ctx, idx);
+        }
+
+        idx = gguf_find_key(ctx, KEY_HAS_MINICPMV_PROJ);
+        if (idx != -1) {
+            new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
+        }
+
+        idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
+        if (idx != -1) {
+            new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
+        }
+
+        idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER);
+        if (idx != -1) {
+            new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx);
+        }
+        // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
+
+        GGML_ASSERT(new_clip->has_vision_encoder);
+        GGML_ASSERT(!new_clip->has_text_encoder);
+
+        idx = get_key_idx(ctx, KEY_USE_GELU);
+        new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+
+        try {
+            idx = get_key_idx(ctx, KEY_USE_SILU);
+            new_clip->use_silu = gguf_get_val_bool(ctx, idx);
+        } catch (std::runtime_error & /*e*/) {
+            new_clip->use_silu = false;
+        }
+
+        if (verbosity >= 1) {
+            LOG_INF("%s: text_encoder:   %d\n", __func__, new_clip->has_text_encoder);
+            LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder);
+            LOG_INF("%s: llava_projector:  %d\n", __func__, new_clip->has_llava_projector);
+            LOG_INF("%s: minicpmv_projector:  %d\n", __func__, new_clip->has_minicpmv_projector);
+            LOG_INF("%s: model size:     %.2f MB\n", __func__, model_size / 1024.0 / 1024.0);
+            LOG_INF("%s: metadata size:  %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0);
+        }
+    }
+
+    LOG_INF("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors);
+
+    // load tensors
+    {
+        std::vector<uint8_t> read_buf;
+        struct ggml_init_params params = {
+            /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc =*/ true,
+        };
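+        // no_alloc = true: this context only holds tensor metadata; the weight data is
+        // streamed from the file into the backend buffer further below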
+
+        new_clip->ctx_data = ggml_init(params);
+        if (!new_clip->ctx_data) {
+            LOG_ERR("%s: ggml_init() failed\n", __func__);
+            clip_free(new_clip);
+            gguf_free(ctx);
+            return nullptr;
+        }
+#ifdef _WIN32
+        int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0);
+        if (!wlen) {
+            return NULL;
+        }
+        wchar_t * wbuf = (wchar_t *) malloc(wlen * sizeof(wchar_t));
+        wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wbuf, wlen);
+        if (!wlen) {
+            free(wbuf);
+            return NULL;
+        }
+#if __GLIBCXX__
+        int fd = _wopen(wbuf, _O_RDONLY | _O_BINARY);
+        __gnu_cxx::stdio_filebuf<char> buffer(fd, std::ios_base::in);
+        std::istream fin(&buffer);
+#else // MSVC
+        // unused in our current build
+        auto fin = std::ifstream(wbuf, std::ios::binary);
+#endif
+        free(wbuf);
+#else
+        auto fin = std::ifstream(fname, std::ios::binary);
+#endif
+        if (!fin) {
+            LOG_ERR("cannot open model file for loading tensors\n");
+            clip_free(new_clip);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        // add tensors to context
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor * t = ggml_get_tensor(meta, name);
+            struct ggml_tensor * cur = ggml_dup_tensor(new_clip->ctx_data, t);
+            ggml_set_name(cur, name);
+        }
+
+        // alloc memory and offload data
+        new_clip->params_buffer = ggml_backend_alloc_ctx_tensors(new_clip->ctx_data, new_clip->backend);
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx, i);
+            struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
+            const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
+            fin.seekg(offset, std::ios::beg);
+            if (!fin) {
+                LOG_ERR("%s: failed to seek for tensor %s\n", __func__, name);
+                clip_free(new_clip);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            int num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_buffer_is_host(new_clip->params_buffer)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
+        }
+#if defined(_WIN32) && defined(__GLIBCXX__)
+        close(fd);
+#else
+        fin.close();
+#endif
+    }
+
+    // vision model
+    if (new_clip->has_vision_encoder) {
+        // load vision model
+        auto & vision_model = new_clip->vision_model;
+        auto & hparams = vision_model.hparams;
+        hparams.hidden_size    = get_u32(ctx, format(KEY_N_EMBD, "vision"));
+        hparams.n_head         = get_u32(ctx, format(KEY_N_HEAD, "vision"));
+        hparams.n_intermediate = get_u32(ctx, format(KEY_N_FF, "vision"));
+        hparams.n_layer        = get_u32(ctx, format(KEY_N_BLOCK, "vision"));
+        hparams.image_size     = get_u32(ctx, KEY_IMAGE_SIZE);
+        hparams.patch_size     = get_u32(ctx, KEY_PATCH_SIZE);
+        hparams.projection_dim = get_u32(ctx, format(KEY_PROJ_DIM, "vision"));
+        hparams.eps            = get_f32(ctx, format(KEY_LAYER_NORM_EPS, "vision"));
+
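+        // optional llava-1.6 style image grid pinpoints; the array is treated as
+        // zero-terminated, so a trailing 0 is written as a sentinel when it fits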
+        try {
+            int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS);
+            int n = gguf_get_arr_n(ctx, idx);
+            const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx);
+            for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) {
+                hparams.image_grid_pinpoints[i] = pinpoints[i];
+            }
+            if (n < 32)
+                hparams.image_grid_pinpoints[n] = 0;
+        } catch (std::runtime_error & /*e*/) {
+            hparams.image_grid_pinpoints[0]=0;
+        }
+
+        try {
+            int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE);
+            strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx));
+        } catch (std::runtime_error & /*e*/) {
+            strcpy(hparams.mm_patch_merge_type, "flat");
+        }
+
+        try {
+            hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6
+        } catch(const std::exception& /*e*/) {
+            hparams.image_crop_resolution = hparams.image_size;
+        }
+
+        int idx_mean = get_key_idx(ctx, KEY_IMAGE_MEAN);
+        int idx_std  = get_key_idx(ctx, KEY_IMAGE_STD);
+
+        const float * mean_data = (const float *)gguf_get_arr_data(ctx, idx_mean);
+        const float * std_data  = (const float *)gguf_get_arr_data(ctx, idx_std);
+
+        for (int i = 0; i < 3; ++i) {
+            new_clip->image_mean[i] = mean_data[i];
+            new_clip->image_std[i]  = std_data[i];
+        }
+
+        if (verbosity >= 2) {
+            LOG_INF("\n%s: vision model hparams\n", __func__);
+            LOG_INF("image_size         %d\n", hparams.image_size);
+            LOG_INF("patch_size         %d\n", hparams.patch_size);
+            LOG_INF("v_hidden_size      %d\n", hparams.hidden_size);
+            LOG_INF("v_n_intermediate   %d\n", hparams.n_intermediate);
+            LOG_INF("v_projection_dim   %d\n", hparams.projection_dim);
+            LOG_INF("v_n_head           %d\n", hparams.n_head);
+            LOG_INF("v_n_layer          %d\n", hparams.n_layer);
+            LOG_INF("v_eps              %f\n", hparams.eps);
+            LOG_INF("v_image_mean       %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]);
+            LOG_INF("v_image_std        %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]);
+            LOG_INF("v_image_grid_pinpoints: ");
+            for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) {
+                LOG_INF("%d ", hparams.image_grid_pinpoints[i]);
+            }
+            LOG_INF("\n");
+            LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type);
+
+        }
+
+        try {
+            vision_model.class_embedding  = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD);
+            new_clip->has_class_embedding = true;
+        } catch (const std::exception& /*e*/) {
+            new_clip->has_class_embedding = false;
+        }
+
+        try {
+            vision_model.pre_ln_w  = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight"));
+            vision_model.pre_ln_b  = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias"));
+            new_clip->has_pre_norm = true;
+        } catch (std::exception & /*e*/) {
+            new_clip->has_pre_norm = false;
+        }
+
+        try {
+            vision_model.post_ln_w  = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
+            vision_model.post_ln_b  = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
+            new_clip->has_post_norm = true;
+        } catch (std::exception & /*e*/) {
+            new_clip->has_post_norm = false;
+        }
+
+        try {
+            vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS);
+            new_clip->has_patch_bias = true;
+        } catch (std::exception & /*e*/) {
+            new_clip->has_patch_bias = false;
+        }
+
+        try {
+            vision_model.patch_embeddings_0    = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+            vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
+        } catch(const std::exception& /*e*/) {
+            LOG_ERR("%s: failed to load vision model tensors\n", __func__);
+        }
+        try {
+            vision_model.patch_embeddings_1    = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
+        } catch(const std::exception& /*e*/) {
+            new_clip->has_qwen2vl_merger = false;
+        }
+
+        // LLaVA projection
+        if (new_clip->proj_type == PROJECTOR_TYPE_MLP || new_clip->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+            vision_model.mm_0_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
+            vision_model.mm_0_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
+            try {
+                // Yi-type llava
+                vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight"));
+                vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                // missing in Yi-type llava
+                vision_model.mm_2_w              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
+                vision_model.mm_2_b              = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                // Yi-type llava
+                vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight"));
+                vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                // Yi-type llava
+                vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight"));
+                vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias"));
+            } catch (std::runtime_error & /*e*/) { }
+            try {
+                vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE);
+                // LOG_INF("%s: image_newline tensor (llava-1.6) found\n", __func__);
+            } catch (std::runtime_error & /*e*/) { }
+        } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) {
+            // MobileVLM projection
+            vision_model.mm_model_mlp_1_w               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight"));
+            vision_model.mm_model_mlp_1_b               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "bias"));
+            vision_model.mm_model_mlp_3_w               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "weight"));
+            vision_model.mm_model_mlp_3_b               = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 3, "bias"));
+            vision_model.mm_model_block_1_block_0_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
+            vision_model.mm_model_block_1_block_0_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
+            vision_model.mm_model_block_1_block_0_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
+            vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
+            vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
+            vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
+            vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
+            vision_model.mm_model_block_1_block_2_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
+            vision_model.mm_model_block_1_block_2_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
+            vision_model.mm_model_block_1_block_2_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
+            vision_model.mm_model_block_2_block_0_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
+            vision_model.mm_model_block_2_block_0_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
+            vision_model.mm_model_block_2_block_0_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
+            vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
+            vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
+            vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
+            vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
+            vision_model.mm_model_block_2_block_2_0_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
+            vision_model.mm_model_block_2_block_2_1_w   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
+            vision_model.mm_model_block_2_block_2_1_b   = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
+        }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_LDPV2)
+        {
+            // MobileVLM_V2 projection
+            vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "weight"));
+            vision_model.mm_model_mlp_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 0, "bias"));
+            vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "weight"));
+            vision_model.mm_model_mlp_2_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 2, "bias"));
+            vision_model.mm_model_peg_0_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "weight"));
+            vision_model.mm_model_peg_0_b = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_PEG, 0, "bias"));
+        }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+            // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+            vision_model.mm_model_pos_embed_k = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD_K);
+            vision_model.mm_model_query = get_tensor(new_clip->ctx_data, TN_MINICPMV_QUERY);
+            vision_model.mm_model_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_PROJ);
+            vision_model.mm_model_kv_proj = get_tensor(new_clip->ctx_data, TN_MINICPMV_KV_PROJ);
+            vision_model.mm_model_attn_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "weight"));
+            vision_model.mm_model_attn_k_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "weight"));
+            vision_model.mm_model_attn_v_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "weight"));
+            vision_model.mm_model_attn_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "q", "bias"));
+            vision_model.mm_model_attn_k_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "k", "bias"));
+            vision_model.mm_model_attn_v_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "v", "bias"));
+            vision_model.mm_model_attn_o_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "weight"));
+            vision_model.mm_model_attn_o_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_ATTN, "out", "bias"));
+            vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "weight"));
+            vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "q", "bias"));
+            vision_model.mm_model_ln_kv_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "weight"));
+            vision_model.mm_model_ln_kv_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "kv", "bias"));
+            vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight"));
+            vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias"));
+        }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) {
+            vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight"));
+            vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias"));
+            vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
+            vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
+        }
+        else {
+            std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
+            throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
+        }
+
+        vision_model.layers.resize(hparams.n_layer);
+
+        for (int il = 0; il < hparams.n_layer; ++il) {
+            auto & layer = vision_model.layers[il];
+            layer.k_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_K,      "v", il, "weight"));
+            layer.q_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q,      "v", il, "weight"));
+            layer.v_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_V,      "v", il, "weight"));
+            layer.o_w    = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "weight"));
+            layer.ln_1_w = get_tensor(new_clip->ctx_data, format(TN_LN_1,        "v", il, "weight"));
+            layer.ln_2_w = get_tensor(new_clip->ctx_data, format(TN_LN_2,        "v", il, "weight"));
+            layer.ff_i_w = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN,    "v", il, "weight"));
+            layer.ff_o_w = get_tensor(new_clip->ctx_data, format(TN_FFN_UP,      "v", il, "weight"));
+            layer.k_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_K,      "v", il, "bias"));
+            layer.q_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_Q,      "v", il, "bias"));
+            layer.v_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_V,      "v", il, "bias"));
+            layer.o_b    = get_tensor(new_clip->ctx_data, format(TN_ATTN_OUTPUT, "v", il, "bias"));
+            layer.ln_1_b = get_tensor(new_clip->ctx_data, format(TN_LN_1,        "v", il, "bias"));
+            layer.ln_2_b = get_tensor(new_clip->ctx_data, format(TN_LN_2,        "v", il, "bias"));
+            layer.ff_i_b = get_tensor(new_clip->ctx_data, format(TN_FFN_DOWN,    "v", il, "bias"));
+            layer.ff_o_b = get_tensor(new_clip->ctx_data, format(TN_FFN_UP,      "v", il, "bias"));
+        }
+    }
+
+    ggml_free(meta);
+
+    new_clip->ctx_gguf = ctx;
+
+    // measure mem requirement and allocate
+    {
+        new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+        new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
+        clip_image_f32_batch batch;
+        batch.size = 1;
+        batch.data = nullptr;
+        ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false);
+        ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+        size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
+        LOG_INF("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
+    }
+
+    return new_clip;
+}
+
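+// Record the size of the image currently being processed on the context; later graph building and
+// position-embedding code (e.g. the MiniCPM-V resampler path) reads it back.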
+void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
+    ctx_clip->load_image_size = load_image_size;
+}
+
+struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
+    return ctx_clip->load_image_size;
+}
+
+struct clip_image_size * clip_image_size_init() {
+    struct clip_image_size * load_image_size = new struct clip_image_size();
+    load_image_size->width = 448;
+    load_image_size->height = 448;
+    return load_image_size;
+}
+
+struct clip_image_u8 * clip_image_u8_init() {
+    return new clip_image_u8();
+}
+
+struct clip_image_f32 * clip_image_f32_init() {
+    return new clip_image_f32();
+}
+
+void clip_image_u8_free(struct clip_image_u8  * img) { delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch  * batch) {
+    if (batch->size > 0) {
+        delete[] batch->data;
+        batch->size = 0;
+    }
+}
+void clip_image_f32_batch_free(struct clip_image_f32_batch  * batch) {
+    if (batch->size > 0) {
+        delete[] batch->data;
+        batch->size = 0;
+    }
+}
+
+static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
+    img->nx = nx;
+    img->ny = ny;
+    img->buf.resize(3 * nx * ny);
+    memcpy(img->buf.data(), data, img->buf.size());
+}
+
+bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
+    int nx, ny, nc;
+    auto * data = stbi_load(fname, &nx, &ny, &nc, 3);
+    if (!data) {
+        LOG_ERR("%s: failed to load image '%s'\n", __func__, fname);
+        return false;
+    }
+    build_clip_img_from_data(data, nx, ny, img);
+    stbi_image_free(data);
+    return true;
+}
+
+bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) {
+    int nx, ny, nc;
+    auto * data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
+    if (!data) {
+        LOG_ERR("%s: failed to decode image bytes\n", __func__);
+        return false;
+    }
+    build_clip_img_from_data(data, nx, ny, img);
+    stbi_image_free(data);
+    return true;
+}
+
+// Linear interpolation between two points
+inline float clip_lerp(float s, float e, float t) {
+    return s + (e - s) * t;
+}
+// Bilinear resize function
+static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+    dst.nx = target_width;
+    dst.ny = target_height;
+    dst.buf.resize(3 * target_width * target_height);
+
+    float x_ratio = static_cast<float>(src.nx - 1) / target_width;
+    float y_ratio = static_cast<float>(src.ny - 1) / target_height;
+
+    for (int y = 0; y < target_height; y++) {
+        for (int x = 0; x < target_width; x++) {
+            float px = x_ratio * x;
+            float py = y_ratio * y;
+            int x_floor = static_cast<int>(px);
+            int y_floor = static_cast<int>(py);
+            float x_lerp = px - x_floor;
+            float y_lerp = py - y_floor;
+
+            for (int c = 0; c < 3; c++) {
+                float top = clip_lerp(
+                    static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
+                    static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
+                    x_lerp
+                );
+                float bottom = clip_lerp(
+                    static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
+                    static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
+                    x_lerp
+                );
+                dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(clip_lerp(top, bottom, y_lerp));
+            }
+        }
+    }
+}
+
+// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
+static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) {
+    dst->nx = src->nx;
+    dst->ny = src->ny;
+    dst->buf.resize(src->buf.size());
+
+    for (size_t i = 0; i < src->buf.size(); ++i) {
+        int c = i % 3; // rgb
+        dst->buf[i] = (static_cast<float>(src->buf[i]) / 255.0f - mean[c]) / std[c];
+    }
+}
+
+inline int clip(int x, int lower, int upper) {
+    return std::max(lower, std::min(x, upper));
+}
+
+static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) {
+    const int nx = img.nx;
+    const int ny = img.ny;
+
+    dst.nx = target_width;
+    dst.ny = target_height;
+    dst.buf.resize(3 * target_width * target_height);
+
+    float Cc;
+    float C[5];
+    float d0, d2, d3, a0, a1, a2, a3;
+    int i, j, k, jj;
+    int x, y;
+    float dx, dy;
+    float tx, ty;
+
+    tx = (float)nx / (float)target_width;
+    ty = (float)ny / (float)target_height;
+
+    // Bicubic interpolation; adapted from ViT.cpp, inspired from :
+    //    -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+    //    -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+    for (i = 0; i < target_height; i++) {
+        for (j = 0; j < target_width; j++) {
+            x = (int)(tx * j);
+            y = (int)(ty * i);
+
+            dx = tx * j - x;
+            dy = ty * i - y;
+
+            for (k = 0; k < 3; k++) {
+                for (jj = 0; jj <= 3; jj++) {
+                    d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                    a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                    C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                    d0 = C[0] - C[1];
+                    d2 = C[2] - C[1];
+                    d3 = C[3] - C[1];
+                    a0 = C[1];
+                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                    a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                    a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+                    Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                    const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                    dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+// llava-1.6 type of resize_and_pad (black)
+static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) {
+    int target_width = target_resolution.first;
+    int target_height = target_resolution.second;
+
+    float scale_w = static_cast<float>(target_width) / image.nx;
+    float scale_h = static_cast<float>(target_height) / image.ny;
+
+    int new_width, new_height;
+
+    if (scale_w < scale_h) {
+        new_width = target_width;
+        new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
+    } else {
+        new_height = target_height;
+        new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
+    }
+
+    clip_image_u8 resized_image;
+    // bilinear_resize(image, resized_image, new_width, new_height);
+    bicubic_resize(image, resized_image, new_width, new_height);
+
+    clip_image_u8 padded_image;
+    padded_image.nx = target_width;
+    padded_image.ny = target_height;
+    padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
+
+    // Calculate padding offsets
+    int pad_x = (target_width - new_width) / 2;
+    int pad_y = (target_height - new_height) / 2;
+
+    // Copy the resized image into the center of the padded buffer
+    for (int y = 0; y < new_height; ++y) {
+        for (int x = 0; x < new_width; ++x) {
+            for (int c = 0; c < 3; ++c) {
+                padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+            }
+        }
+    }
+    image_output = std::move(padded_image);
+}
+
+/**
+ * Selects the best resolution from a list of possible resolutions based on the original size.
+ *
+ * @param original_size The original size of the image in the format (width, height).
+ * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+ * @return The best fit resolution in the format (width, height).
+ */
+static std::pair<int, int> select_best_resolution(const std::pair<int, int> & original_size, const std::vector<std::pair<int, int>> & possible_resolutions) {
+    int original_width = original_size.first;
+    int original_height = original_size.second;
+    std::pair<int, int> best_fit;
+    int max_effective_resolution = 0;
+    int min_wasted_resolution = std::numeric_limits<int>::max();
+
+    for (const auto& resolution : possible_resolutions) {
+        int width = resolution.first;
+        int height = resolution.second;
+        float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
+        int downscaled_width = static_cast<int>(original_width * scale);
+        int downscaled_height = static_cast<int>(original_height * scale);
+        int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
+        int wasted_resolution = (width * height) - effective_resolution;
+        // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
+        if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
+            max_effective_resolution = effective_resolution;
+            min_wasted_resolution = wasted_resolution;
+            best_fit = resolution;
+        }
+    }
+
+    return best_fit;
+}
+
+static std::vector<clip_image_u8 *> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
+    std::vector<clip_image_u8 *> patches;
+    int width = image.nx;
+    int height = image.ny;
+    for (int i = 0; i < height; i += patch_size) {
+        for (int j = 0; j < width; j += patch_size) {
+            clip_image_u8 *patch = clip_image_u8_init();
+            patch->nx = std::min(patch_size, width - j);
+            patch->ny = std::min(patch_size, height - i);
+            patch->buf.resize(3 * patch->nx * patch->ny);
+            for (int y = 0; y < patch->ny; ++y) {
+                for (int x = 0; x < patch->nx; ++x) {
+                    for (int c = 0; c < 3; ++c) {
+                        patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
+                    }
+                }
+            }
+            patches.push_back(patch);
+        }
+    }
+    return patches;
+}
+
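+// Round length to the nearest multiple of patch_size, never returning less than a single patch.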
+static int ensure_divide(int length, int patch_size) {
+    return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
+}
+
+static std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    int width = original_size.first;
+    int height = original_size.second;
+    if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+        float r = static_cast<float>(width) / height;
+        height = static_cast<int>(scale_resolution / std::sqrt(r));
+        width = static_cast<int>(height * r);
+    }
+    int best_width = ensure_divide(width, patch_size);
+    int best_height = ensure_divide(height, patch_size);
+    return std::make_pair(best_width, best_height);
+}
+
+static std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+    int width, height;
+    std::tie(width, height) = original_size;
+    int grid_x, grid_y;
+    std::tie(grid_x, grid_y) = grid;
+
+    int refine_width = ensure_divide(width, grid_x);
+    int refine_height = ensure_divide(height, grid_y);
+
+    int grid_width = refine_width / grid_x;
+    int grid_height = refine_height / grid_y;
+
+   // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
+    auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
+    int best_grid_width, best_grid_height;
+    std::tie(best_grid_width, best_grid_height) = best_grid_size;
+
+  //  std::pair refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
+    std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
+    return refine_size;
+}
+
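+// Pick the slice grid (columns x rows) whose aspect ratio is closest, in log space, to the aspect ratio of the original image.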
+static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+    std::vector<int> candidate_split_grids_nums;
+    for (int i : {multiple - 1, multiple, multiple + 1}) {
+        if (i == 1 || i > max_slice_nums) {
+            continue;
+        }
+        candidate_split_grids_nums.push_back(i);
+    }
+
+    std::vector<std::pair<int, int>> candidate_grids;
+    for (int split_grids_nums : candidate_split_grids_nums) {
+        int m = 1;
+        while (m <= split_grids_nums) {
+            if (split_grids_nums % m == 0) {
+                candidate_grids.emplace_back(m, split_grids_nums / m);
+            }
+            ++m;
+        }
+    }
+
+    std::pair<int, int> best_grid{1, 1};
+    float min_error = std::numeric_limits<float>::infinity();
+    for (const auto& grid : candidate_grids) {
+        float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
+        if (error < min_error) {
+            best_grid = grid;
+            min_error = error;
+        }
+    }
+    return best_grid;
+}
+
+// inspired from LLaVA-UHD:
+//    -> https://arxiv.org/pdf/2403.11703
+//    -> https://github.com/thunlp/LLaVA-UHD
+//    -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
+static std::vector<std::vector<clip_image_u8 *>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
+    const std::pair<int, int> original_size={img->nx,img->ny};
+    const int original_width = img->nx;
+    const int original_height = img->ny;
+    const float log_ratio = log(1.0*original_width/original_height);
+    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
+    const int multiple = fmin(ceil(ratio), max_slice_nums);
+
+    std::vector<std::vector<clip_image_u8 *>> images;
+    LOG_INF("%s: multiple %d\n", __func__, multiple);
+    images.push_back(std::vector<clip_image_u8 *>());
+
+    if (multiple <= 1) {
+        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
+        clip_image_u8 * source_image = clip_image_u8_init();
+        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
+        // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+        images[images.size()-1].push_back(source_image);
+    }
+    else if (multiple > 1) {
+        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
+        clip_image_u8 * source_image = clip_image_u8_init();
+        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
+        // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+        LOG_INF("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
+        images[images.size()-1].push_back(source_image);
+
+        std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+        LOG_INF("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
+
+        auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
+        clip_image_u8 * refine_image = clip_image_u8_init();
+        bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
+
+        LOG_INF("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
+
+        // split_to_patches
+        int width = refine_image->nx;
+        int height = refine_image->ny;
+        int grid_x = int(width / best_grid.first);
+        int grid_y = int(height / best_grid.second);
+        for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
+            images.push_back(std::vector<clip_image_u8 *>());
+            for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
+                clip_image_u8 * patch = clip_image_u8_init();
+                patch->nx = grid_x;
+                patch->ny = grid_y;
+                patch->buf.resize(3 * patch->nx * patch->ny);
+                for (int y = patches_i; y < patches_i + grid_y; ++y) {
+                    for (int x = patches_j; x < patches_j + grid_x; ++x) {
+                        const int i = 3 * (y * refine_image->nx + x);
+                        const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j));
+                        patch->buf[j]   = refine_image->buf[i];
+                        patch->buf[j+1] = refine_image->buf[i+1];
+                        patch->buf[j+2] = refine_image->buf[i+2];
+                    }
+                }
+                images[images.size()-1].push_back(patch);
+            }
+        }
+    }
+    return images;
+}
+
+int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
+    const int max_slice_nums=9;
+    const int scale_resolution=448;
+    const int original_width = ctx_clip->load_image_size->width;
+    const int original_height = ctx_clip->load_image_size->height;
+    const float log_ratio = log(1.0*original_width/original_height);
+    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
+    const int multiple = fmin(ceil(ratio), max_slice_nums);
+    std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
+    return best_grid.first;
+}
+
+// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
+// res_imgs memory is being allocated here, previous allocations will be freed if found
+bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
+
+    if(clip_is_minicpmv(ctx)){
+        int max_slice_nums = 9;
+        std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
+        res_imgs->size = 0;
+        for (size_t i = 0; i < imgs.size(); ++i){
+            res_imgs->size += imgs[i].size();
+        }
+        res_imgs->data = new clip_image_f32[res_imgs->size];
+        int idx = 0;
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            for (size_t j = 0; j < imgs[i].size(); ++j) {
+                LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
+                clip_image_f32 * res = clip_image_f32_init();
+                normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std);
+                res_imgs->data[idx++] = *res;
+                clip_image_f32_free(res);
+            }
+        }
+        return true;
+    }
+    else if (ctx->has_qwen2vl_merger) {
+        clip_image_u8 * resized = clip_image_u8_init();
+        auto patch_size = clip_patch_size(ctx) * 2;
+        int nx = ceil((float)img->nx / patch_size) * patch_size;
+        int ny = ceil((float)img->ny / patch_size) * patch_size;
+        bicubic_resize(*img, *resized, nx, ny);
+
+        res_imgs->data = new clip_image_f32[1];
+        // clip_image_f32 * res = clip_image_f32_init();
+        normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std);
+        // res_imgs->data[0] = *res;
+        res_imgs->size = 1;
+
+        // clip_image_f32_free(res);
+        clip_image_u8_free(resized);
+        return true;
+    }
+
+    bool pad_to_square = true;
+    if (!ctx->has_vision_encoder) {
+        LOG_ERR("This gguf file seems to have no vision encoder\n");
+        return false;
+    }
+    auto & params = ctx->vision_model.hparams;
+    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+    if (strcmp(params.mm_patch_merge_type, "spatial_unpad") == 0) {
+        pad_to_square = false;
+    }
+    // free the previous res_imgs if any set
+    if (res_imgs->size > 0) {
+        clip_image_f32_batch_free(res_imgs);
+    }
+    res_imgs->data = nullptr;
+    res_imgs->size = 0;
+
+    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+    clip_image_u8 * temp = clip_image_u8_init(); // we will keep the input image data here temporarily
+    if (pad_to_square && img->nx != img->ny) {
+        int longer_side = std::max(img->nx, img->ny);
+        temp->nx = longer_side;
+        temp->ny = longer_side;
+        temp->buf.resize(3 * longer_side * longer_side);
+        const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
+
+        // fill with background color
+        for (size_t i = 0; i < temp->buf.size(); i++) {
+            temp->buf[i] = bc[i % 3];
+        }
+
+        // copy from the input image
+        for (int y = 0; y < img->ny; y++) {
+            for (int x = 0; x < img->nx; x++) {
+                const int i = 3 * (y * img->nx + x);
+                const int j = 3 * (y * temp->nx + x);
+                temp->buf[j]   = img->buf[i];
+                temp->buf[j+1] = img->buf[i+1];
+                temp->buf[j+2] = img->buf[i+2];
+            }
+        }
+    } else {
+        if (params.image_grid_pinpoints[0] != 0) {
+            // "spatial_unpad" with "anyres" processing for llava-1.6
+            std::vector<std::pair<int, int>> possible_resolutions;
+            for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) {
+                possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
+            }
+            std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
+            // clip_image_save_to_bmp(*img, "input.bmp");
+            resize_and_pad_image(*img, *temp, best_resolution);  // we do not pad with mean-bg color anymore in llava-1.6
+            // clip_image_save_to_bmp(*temp, "resized.bmp");
+            // visually verify normalized image:
+            // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
+            // {
+            //     clip_image_u8 * temp2 = clip_image_u8_init();
+            //     clip_image_convert_f32_to_u8(*res, *temp2);
+            //     clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp");
+            //     clip_image_u8_free(temp2);
+            // }
+
+            std::vector<clip_image_u8 *> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
+
+            clip_image_u8 *image_original_resize = clip_image_u8_init();
+            // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
+            patches.insert(patches.begin(), image_original_resize);
+            // clip_image_f32_batch_init(patches.size());
+            res_imgs->size = patches.size();
+            res_imgs->data = new clip_image_f32[res_imgs->size];
+            int num=0;
+            for (auto& patch : patches) {
+                normalize_image_u8_to_f32(patch, &res_imgs->data[num], ctx->image_mean, ctx->image_std);
+                num++;
+            }
+
+            for (size_t i = 0; i < patches.size(); i++) {
+                // LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny);
+                clip_image_u8_free(patches[i]);
+            }
+
+            clip_image_u8_free(temp);
+
+            return true;
+        } else {
+            temp->nx = img->nx;
+            temp->ny = img->ny;
+            temp->buf.resize(img->buf.size());
+            memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
+        }
+    }
+
+    const int nx = temp->nx;
+    const int ny = temp->ny;
+    // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
+
+    const int nx2 = ctx->vision_model.hparams.image_size;
+    const int ny2 = ctx->vision_model.hparams.image_size;
+    clip_image_f32 * res = clip_image_f32_init();
+    res->nx = nx2;
+    res->ny = ny2;
+    res->buf.resize(3 * nx2 * ny2);
+
+    const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size;
+
+    const int nx3 = int(nx / scale + 0.5f);
+    const int ny3 = int(ny / scale + 0.5f);
+
+    const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
+    const auto & s3 = ctx->image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
+
+    for (int y = 0; y < ny3; y++) {
+        for (int x = 0; x < nx3; x++) {
+            for (int c = 0; c < 3; c++) {
+                // linear interpolation
+                const float sx = (x + 0.5f) * scale - 0.5f;
+                const float sy = (y + 0.5f) * scale - 0.5f;
+
+                const int x0 = std::max(0, (int)std::floor(sx));
+                const int y0 = std::max(0, (int)std::floor(sy));
+
+                const int x1 = std::min(x0 + 1, nx - 1);
+                const int y1 = std::min(y0 + 1, ny - 1);
+
+                const float dx = sx - x0;
+                const float dy = sy - y0;
+
+                const int j00 = 3 * (y0 * nx + x0) + c;
+                const int j01 = 3 * (y0 * nx + x1) + c;
+                const int j10 = 3 * (y1 * nx + x0) + c;
+                const int j11 = 3 * (y1 * nx + x1) + c;
+
+                const float v00 = temp->buf[j00];
+                const float v01 = temp->buf[j01];
+                const float v10 = temp->buf[j10];
+                const float v11 = temp->buf[j11];
+
+                const float v0 = v00 * (1.0f - dx) + v01 * dx;
+                const float v1 = v10 * (1.0f - dx) + v11 * dx;
+
+                const float v = v0 * (1.0f - dy) + v1 * dy;
+
+                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
+
+                const int i = 3 * (y * nx3 + x) + c;
+
+                res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
+            }
+        }
+    }
+    clip_image_u8_free(temp);
+
+    // {
+    //     clip_image_u8 * temp2 = clip_image_u8_init();
+    //     clip_image_convert_f32_to_u8(*res, *temp2);
+    //     clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp");
+    //     clip_image_u8_free(temp2);
+    // }
+    // res_imgs.push_back(res);
+
+    res_imgs->size = 1;
+    res_imgs->data = new clip_image_f32[res_imgs->size];
+    res_imgs->data[0] = *res;
+    clip_image_f32_free(res);
+
+    return true;
+}
+
+ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
+    return ctx->vision_model.image_newline;
+}
+
+void clip_free(clip_ctx * ctx) {
+    ggml_free(ctx->ctx_data);
+    gguf_free(ctx->ctx_gguf);
+
+    ggml_backend_buffer_free(ctx->params_buffer);
+    ggml_backend_free(ctx->backend);
+    ggml_gallocr_free(ctx->compute_alloc);
+    delete ctx;
+}
+
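+// Number of bytes needed to hold the embedding of one image: n_patches x projector output dim x sizeof(float).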
+size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
+    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
+}
+
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
+    clip_image_f32 img;
+    img.nx = img_w;
+    img.ny = img_h;
+    return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
+}
+
+int32_t clip_image_size(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.image_size;
+}
+
+int32_t clip_patch_size(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.patch_size;
+}
+
+int32_t clip_hidden_size(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.hidden_size;
+}
+
+const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.mm_patch_merge_type;
+}
+
+const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
+    return ctx->vision_model.hparams.image_grid_pinpoints;
+}
+
+int clip_n_patches(const struct clip_ctx * ctx) {
+    clip_image_f32 img;
+    img.nx = ctx->vision_model.hparams.image_size;
+    img.ny = ctx->vision_model.hparams.image_size;
+    return clip_n_patches_by_img(ctx, &img);
+}
+
+int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->vision_model.hparams;
+
+    int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+
+    if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
+        n_patches /= 4;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+        if (ctx->minicpmv_version == 2) {
+            n_patches = 96;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            n_patches = 64;
+        }
+    } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+        int patch_size = params.patch_size * 2;
+        int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+        int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+        n_patches = x_patch * y_patch;
+    }
+
+    return n_patches;
+}
+
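+// Build 1D sinusoidal position embeddings for every grid position: the first embed_dim/2 channels hold the sin terms,
+// the second half the cos terms.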
+static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
+    assert(embed_dim % 2 == 0);
+    int H = pos.size();
+    int W = pos[0].size();
+
+    std::vector<float> omega(embed_dim / 2);
+    for (int i = 0; i < embed_dim / 2; ++i) {
+        omega[i] = 1.0 / pow(10000.0, static_cast<float>(i) / (embed_dim / 2));
+    }
+
+    std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                float out_value = pos[h][w] * omega[d];
+                emb[h][w][d] = sin(out_value);
+                emb[h][w][d + embed_dim / 2] = cos(out_value);
+            }
+        }
+    }
+
+    return emb;
+}
+
+static std::vector<std::vector<std::vector<float>>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector<std::vector<std::vector<float>>> & grid) {
+    assert(embed_dim % 2 == 0);
+    std::vector<std::vector<std::vector<float>>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2)
+    std::vector<std::vector<std::vector<float>>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2)
+
+    int H = emb_h.size();
+    int W = emb_h[0].size();
+    std::vector<std::vector<std::vector<float>>> emb(H, std::vector<std::vector<float>>(W, std::vector<float>(embed_dim)));
+
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                emb[h][w][d] = emb_h[h][w][d];
+                emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
+            }
+        }
+    }
+    return emb;
+}
+
+static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, const std::pair<int, int> image_size) {
+    int grid_h_size = image_size.first;
+    int grid_w_size = image_size.second;
+
+    std::vector<float> grid_h(grid_h_size);
+    std::vector<float> grid_w(grid_w_size);
+
+    for (int i = 0; i < grid_h_size; ++i) {
+        grid_h[i] = static_cast<float>(i);
+    }
+    for (int i = 0; i < grid_w_size; ++i) {
+        grid_w[i] = static_cast<float>(i);
+    }
+
+    std::vector<std::vector<float>> grid(grid_h_size, std::vector<float>(grid_w_size));
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid[h][w] = grid_w[w];
+        }
+    }
+    std::vector<std::vector<std::vector<float>>> grid_2d = {grid, grid};
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid_2d[0][h][w] = grid_h[h];
+            grid_2d[1][h][w] = grid_w[w];
+        }
+    }
+
+    std::vector<std::vector<std::vector<float>>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
+
+    int H = image_size.first;
+    int W = image_size.second;
+    std::vector<std::vector<float>> pos_embed_2d(H * W, std::vector<float>(embed_dim));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            pos_embed_2d[w * H + h] = pos_embed_3d[h][w];
+        }
+    }
+
+    return pos_embed_2d;
+}
+
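+// Encode a single preprocessed image by wrapping it in a batch of size one.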
+bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
+    if (!ctx->has_vision_encoder) {
+        LOG_ERR("This gguf file seems to have no vision encoder\n");
+        return false;
+    }
+
+    clip_image_f32_batch imgs{};
+    imgs.size = 1;
+    imgs.data = img;
+    return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
+}
+
+bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) {
+    if (!ctx->has_vision_encoder) {
+        LOG_ERR("This gguf file seems to have no vision encoder\n");
+        return false;
+    }
+
+    int batch_size = imgs->size;
+    if (ctx->has_llava_projector) {
+        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
+    }
+    if (ctx->has_minicpmv_projector) {
+        GGML_ASSERT(batch_size == 1);
+    }
+
+    // build the inference graph
+    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
+    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+    // set inputs
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+    if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) {
+        image_size_width  = imgs->data[0].nx;
+        image_size_height = imgs->data[0].ny;
+    }
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
+    if(ctx->load_image_size==nullptr){
+        ctx->load_image_size= clip_image_size_init();
+    }
+    const int pos_w = ctx->load_image_size->width/patch_size;
+    const int pos_h = ctx->load_image_size->height/patch_size;
+
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+        for (size_t i = 0; i < imgs->size; i++) {
+            const int nx = imgs->data[i].nx;
+            const int ny = imgs->data[i].ny;
+            if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
+                GGML_ASSERT(nx == image_size && ny == image_size);
+            }
+
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
+    }
+    if (ctx->has_minicpmv_projector) {
+        {
+            // inspired from siglip:
+            //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+            //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
+            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+            int bucket_coords_h[70];
+            int bucket_coords_w[70];
+            for (int i = 0; i < pos_h; i++){
+                bucket_coords_h[i] = std::floor(70.0*i/pos_h);
+            }
+            for (int i = 0; i < pos_w; i++){
+                bucket_coords_w[i] = std::floor(70.0*i/pos_w);
+            }
+            for (int i = 0, id = 0; i < pos_h; i++){
+                for (int j = 0; j < pos_w; j++){
+                    positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                }
+            }
+            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+            free(positions_data);
+        }
+
+        {
+            // inspired from resampler of Qwen-VL:
+            //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main
+            //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
+            struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
+            int embed_dim = 4096;
+            if (ctx->minicpmv_version == 2) {
+                embed_dim = 4096;
+            }
+            else if (ctx->minicpmv_version == 3) {
+                embed_dim = 3584;
+            }
+            auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
+
+            float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
+            for(int i=0;i < pos_w * pos_h; ++i){
+                for(int j=0; j < embed_dim; ++j){
+                    pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
+                }
+            }
+
+            ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
+            free(pos_embed_data);
+        }
+    }
+    else{
+        {
+            if (ctx->has_class_embedding) {
+                struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+
+                void* zero_mem = malloc(ggml_nbytes(embeddings));
+                memset(zero_mem, 0, ggml_nbytes(embeddings));
+                ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+                free(zero_mem);
+            }
+        }
+
+        if (ctx->has_qwen2vl_merger) {
+            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+            const int pw = image_size_width / patch_size;
+            const int ph = image_size_height / patch_size;
+            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+
+            int ptr = 0;
+            for (int y = 0; y < ph; y+=2)
+            {
+                for (int x = 0; x < pw; x+=2)
+                {
+                    for (int dy = 0; dy < 2; dy++) {
+                        for (int dx = 0; dx < 2; dx++) {
+                            positions_data[ptr]                 = y + dy;
+                            positions_data[num_patches + ptr]     = x + dx;
+                            positions_data[num_patches * 2 + ptr] = y + dy;
+                            positions_data[num_patches * 3 + ptr] = x + dx;
+                            ptr++;
+                        }
+                    }
+                }
+            }
+
+            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+            free(positions_data);
+        }
+        else {
+            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+
+            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+            for (int i = 0; i < num_positions; i++) {
+                positions_data[i] = i;
+            }
+            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+            free(positions_data);
+
+            {
+                struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+                int* patches_data = (int*)malloc(ggml_nbytes(patches));
+                for (int i = 0; i < num_patches; i++) {
+                    patches_data[i] = i + 1;
+                }
+                ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+                free(patches_data);
+            }
+        }
+    }
+
+    if (ggml_backend_is_cpu(ctx->backend)) {
+        ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
+    }
+
+    ggml_backend_graph_compute(ctx->backend, gf);
+
+    // the last node is the embedding tensor
+    struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+
+    // copy the embeddings to the location passed by the user
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+
+    return true;
+}
+
+bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) {
+    ggml_type type = GGML_TYPE_Q4_1;
+
+    assert(itype < GGML_TYPE_COUNT);
+    type = static_cast<ggml_type>(itype);
+
+    auto * ctx_clip = clip_model_load(fname_inp, 2);
+
+    const auto & ctx_src = ctx_clip->ctx_gguf;
+    const auto & ctx_data = ctx_clip->ctx_data;
+
+    auto * ctx_out = gguf_init_empty();
+    gguf_set_kv(ctx_out, ctx_src);
+    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
+    gguf_set_val_u32(ctx_out, "general.file_type", itype);
+
+    auto fout = std::ofstream(fname_out, std::ios::binary);
+
+    const int n_tensors = gguf_get_n_tensors(ctx_src);
+
+    for (int i = 0; i < n_tensors; ++i) {
+        const char * name = gguf_get_tensor_name(ctx_src, i);
+        struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
+        gguf_add_tensor(ctx_out, cur);
+    }
+
+    const size_t meta_size = gguf_get_meta_size(ctx_out);
+    for (size_t i = 0; i < meta_size; ++i) {
+        fout.put(0);
+    }
+
+    // regexes of tensor names to be quantized
+    const std::vector<std::string> k_names = {
+        ".*weight",
+    };
+
+    std::vector<uint8_t> work(512);
+    std::vector<float> conv_buf(512);
+    size_t total_size_org = 0;
+    size_t total_size_new = 0;
+
+    for (int i = 0; i < n_tensors; ++i) {
+        const std::string name = gguf_get_tensor_name(ctx_src, i);
+        struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str());
+
+        enum ggml_type new_type;
+        void * new_data;
+        size_t new_size;
+
+        bool quantize = false;
+        for (const auto & s : k_names) {
+            if (std::regex_match(name, std::regex(s))) {
+                quantize = true;
+                break;
+            }
+        }
+
+        // quantize only 2D tensors
+        quantize &= (ggml_n_dims(cur) == 2);
+
+        if (quantize) {
+            new_type = type;
+            if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) {
+                new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type
+                // LOG_ERR("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type));
+            }
+            const size_t n_elms = ggml_nelements(cur);
+            float * f32_data;
+
+            switch (cur->type) {
+            case GGML_TYPE_F32:
+                f32_data = (float *)cur->data;
+                break;
+            case GGML_TYPE_F16:
+                if (conv_buf.size() < n_elms) {
+                    conv_buf.resize(n_elms);
+                }
+                for (size_t j = 0; j < n_elms; ++j) {
+                    conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]);
+                }
+                f32_data = (float *)conv_buf.data();
+                break;
+            default:
+                LOG_ERR("Please use an input file in f32 or f16\n");
+                gguf_free(ctx_out);
+                return false;
+            }
+
+            if (work.size() < n_elms * 4) {
+                work.resize(n_elms * 4);
+            }
+            new_data = work.data();
+
+            new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr);
+        } else {
+            new_type = cur->type;
+            new_data = cur->data;
+            new_size = ggml_nbytes(cur);
+        }
+        const size_t orig_size = ggml_nbytes(cur);
+        total_size_org += orig_size;
+        total_size_new += new_size;
+        gguf_set_tensor_type(ctx_out, name.c_str(), new_type);
+        gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size);
+        fout.write((const char *)new_data, new_size);
+        size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size;
+        for (size_t j = 0; j < pad; ++j) {
+            fout.put(0);
+        }
+
+        LOG_INF("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize,
+               orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+    }
+
+    // go back to beginning of file and write the updated metadata
+    fout.seekp(0, std::ios::beg);
+    std::vector<uint8_t> meta(meta_size);
+    gguf_get_meta_data(ctx_out, meta.data());
+    fout.write((const char *)meta.data(), meta_size);
+
+    fout.close();
+
+    clip_free(ctx_clip);
+    gguf_free(ctx_out);
+
+    {
+        LOG_INF("%s: original  size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0);
+        LOG_INF("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0);
+    }
+
+    return true;
+}
+
+int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
+    if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+        return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
+        return ctx->vision_model.mm_model_peg_0_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+        return ctx->vision_model.mm_2_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+        return ctx->vision_model.mm_3_b->ne[0];
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+        if (ctx->minicpmv_version == 2) {
+            return 4096;
+        }
+        else if (ctx->minicpmv_version == 3) {
+            return 3584;
+        }
+    }
+    if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+        return ctx->vision_model.mm_1_b->ne[0];
+    }
+
+    std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
+    throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
+}
+
+int clip_is_minicpmv(const struct clip_ctx * ctx) {
+    if (ctx->has_minicpmv_projector) {
+        return ctx->minicpmv_version;
+    }
+    return 0;
+}
+
+bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
+    return ctx->has_qwen2vl_merger;
+}
+
+
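+// Encode an image supplied as a raw float buffer of size h*w*3; the values are copied into a clip_image_f32 as-is,
+// so the caller is expected to have done any resizing/normalization beforehand.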
+bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
+    clip_image_f32 clip_img;
+    clip_img.buf.resize(h * w * 3);
+    for (int i = 0; i < h*w*3; i++)
+    {
+        clip_img.buf[i] = img[i];
+    }
+    clip_img.nx = w;
+    clip_img.ny = h;
+    clip_image_encode(ctx, n_threads, &clip_img, vec);
+    return true;
+}
diff --git a/llama/clip.h b/llama/clip.h
new file mode 100644
index 000000000..42f24bd6c
--- /dev/null
+++ b/llama/clip.h
@@ -0,0 +1,126 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef CLIP_H
+#define CLIP_H
+
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef LLAMA_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef LLAMA_BUILD
+#            define CLIP_API __declspec(dllexport)
+#        else
+#            define CLIP_API __declspec(dllimport)
+#        endif
+#    else
+#        define CLIP_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define CLIP_API
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct clip_ctx;
+
+struct clip_image_size {
+    int width;
+    int height;
+};
+
+struct clip_image_u8_batch {
+    struct clip_image_u8 * data;
+    size_t size;
+};
+
+struct clip_image_f32_batch {
+    struct clip_image_f32 * data;
+    size_t size;
+};
+
+CLIP_API struct clip_ctx * clip_model_load    (const char * fname, int verbosity);
+CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
+
+CLIP_API void clip_free(struct clip_ctx * ctx);
+
+CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
+CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
+
+CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
+CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
+CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
+
+// TODO: should be enum, not string
+CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
+
+CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
+
+CLIP_API int clip_n_patches        (const struct clip_ctx * ctx);
+CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
+CLIP_API int clip_n_mmproj_embd    (const struct clip_ctx * ctx);
+
+CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
+CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
+CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
+
+CLIP_API struct clip_image_size * clip_image_size_init();
+CLIP_API struct clip_image_u8  * clip_image_u8_init ();
+CLIP_API struct clip_image_f32 * clip_image_f32_init();
+
+CLIP_API void clip_image_u8_free (struct clip_image_u8  * img);
+CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
+CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
+CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+
+CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
+
+/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
+CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
+
+/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
+CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
+
+CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
+
+CLIP_API bool clip_image_encode      (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
+CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
+
+CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
+
+CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
+CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+
+CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // CLIP_H
diff --git a/llama/common.cpp b/llama/common.cpp
new file mode 100644
index 000000000..132de88aa
--- /dev/null
+++ b/llama/common.cpp
@@ -0,0 +1,1995 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#if defined(_MSC_VER)
+#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+#endif
+
+#include "common.h"
+#include "log.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "json.hpp"
+#include "json-schema-to-grammar.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <cinttypes>
+#include <climits>
+#include <cmath>
+#include <codecvt>
+#include <cstdarg>
+#include <cstring>
+#include <ctime>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <iterator>
+#include <regex>
+#include <sstream>
+#include <string>
+#include <thread>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#if defined(__APPLE__) && defined(__MACH__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <locale>
+#include <windows.h>
+#include <fcntl.h>
+#include <io.h>
+#else
+#include <sys/ioctl.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#endif
+#if defined(LLAMA_USE_CURL)
+#include <curl/curl.h>
+#include <curl/easy.h>
+#include <future>
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#if defined(LLAMA_USE_CURL)
+#ifdef __linux__
+#include <linux/limits.h>
+#elif defined(_WIN32)
+#   if !defined(PATH_MAX)
+#   define PATH_MAX MAX_PATH
+#   endif
+#else
+#include <sys/syslimits.h>
+#endif
+#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
+#endif // LLAMA_USE_CURL
+
+using json = nlohmann::ordered_json;
+
+//
+// CPU utils
+//
+
+int32_t cpu_get_num_physical_cores() {
+#ifdef __linux__
+    // enumerate the set of thread siblings, num entries is num cores
+    std::unordered_set<std::string> siblings;
+    for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
+        std::ifstream thread_siblings("/sys/devices/system/cpu/cpu"
+            + std::to_string(cpu) + "/topology/thread_siblings");
+        if (!thread_siblings.is_open()) {
+            break; // no more cpus
+        }
+        std::string line;
+        if (std::getline(thread_siblings, line)) {
+            siblings.insert(line);
+        }
+    }
+    if (!siblings.empty()) {
+        return static_cast<int32_t>(siblings.size());
+    }
+#elif defined(__APPLE__) && defined(__MACH__)
+    int32_t num_physical_cores;
+    size_t len = sizeof(num_physical_cores);
+    int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
+    if (result == 0) {
+        return num_physical_cores;
+    }
+    result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
+    if (result == 0) {
+        return num_physical_cores;
+    }
+#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
+    // TODO: windows + arm64 + mingw64
+    unsigned int n_threads_win = std::thread::hardware_concurrency();
+    unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;
+
+    DWORD buffer_size = 0;
+    if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
+        if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
+            return default_threads;
+        }
+    }
+
+    std::vector<char> buffer(buffer_size);
+    if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
+        return default_threads;
+    }
+
+    int32_t num_physical_cores = 0;
+    PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
+    while (buffer_size > 0) {
+        if (info->Relationship == RelationProcessorCore) {
+            num_physical_cores += info->Processor.GroupCount;
+        }
+        buffer_size -= info->Size;
+        info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char *>(info) + info->Size);
+    }
+
+    return num_physical_cores > 0 ? num_physical_cores : default_threads;
+#endif
+    unsigned int n_threads = std::thread::hardware_concurrency();
+    return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
+}
+
+#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+#include <pthread.h>
+
+static void cpuid(unsigned leaf, unsigned subleaf,
+                  unsigned *eax, unsigned *ebx, unsigned *ecx, unsigned *edx) {
+    __asm__("movq\t%%rbx,%%rsi\n\t"
+            "cpuid\n\t"
+            "xchgq\t%%rbx,%%rsi"
+            : "=a"(*eax), "=S"(*ebx), "=c"(*ecx), "=d"(*edx)
+            : "0"(leaf), "2"(subleaf));
+}
+
+static int pin_cpu(int cpu) {
+    cpu_set_t mask;
+    CPU_ZERO(&mask);
+    CPU_SET(cpu, &mask);
+    return pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
+}
+
+static bool is_hybrid_cpu(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(7, 0, &eax, &ebx, &ecx, &edx);
+    return !!(edx & (1u << 15));
+}
+
+static bool is_running_on_efficiency_core(void) {
+    unsigned eax, ebx, ecx, edx;
+    cpuid(0x1a, 0, &eax, &ebx, &ecx, &edx);
+    int intel_atom = 0x20;
+    int core_type = (eax & 0xff000000u) >> 24;
+    return core_type == intel_atom;
+}
+
+static int cpu_count_math_cpus(int n_cpu) {
+    int result = 0;
+    for (int cpu = 0; cpu < n_cpu; ++cpu) {
+        if (pin_cpu(cpu)) {
+            return -1;
+        }
+        if (is_running_on_efficiency_core()) {
+            continue; // efficiency cores harm lockstep threading
+        }
+        ++cpu; // hyperthreading isn't useful for linear algebra
+        ++result;
+    }
+    return result;
+}
+
+#endif // __x86_64__ && __linux__
+
+/**
+ * Returns number of CPUs on system that are useful for math.
+ */
+int32_t cpu_get_num_math() {
+#if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__)
+    int n_cpu = sysconf(_SC_NPROCESSORS_ONLN);
+    if (n_cpu < 1) {
+        return cpu_get_num_physical_cores();
+    }
+    if (is_hybrid_cpu()) {
+        cpu_set_t affinity;
+        if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) {
+            int result = cpu_count_math_cpus(n_cpu);
+            pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity);
+            if (result > 0) {
+                return result;
+            }
+        }
+    }
+#endif
+    return cpu_get_num_physical_cores();
+}
+
+// Helper for setting process priority
+
+#if defined(_WIN32)
+
+bool set_process_priority(enum ggml_sched_priority prio) {
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        return true;
+    }
+
+    DWORD p = NORMAL_PRIORITY_CLASS;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   p = NORMAL_PRIORITY_CLASS;       break;
+        case GGML_SCHED_PRIO_MEDIUM:   p = ABOVE_NORMAL_PRIORITY_CLASS; break;
+        case GGML_SCHED_PRIO_HIGH:     p = HIGH_PRIORITY_CLASS;         break;
+        case GGML_SCHED_PRIO_REALTIME: p = REALTIME_PRIORITY_CLASS;     break;
+    }
+
+    if (!SetPriorityClass(GetCurrentProcess(), p)) {
+        LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
+        return false;
+    }
+
+    return true;
+}
+
+#else // MacOS and POSIX
+#include <sys/types.h>
+#include <sys/resource.h>
+
+bool set_process_priority(enum ggml_sched_priority prio) {
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        return true;
+    }
+
+    int p = 0;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   p =  0;  break;
+        case GGML_SCHED_PRIO_MEDIUM:   p = -5;  break;
+        case GGML_SCHED_PRIO_HIGH:     p = -10; break;
+        case GGML_SCHED_PRIO_REALTIME: p = -20; break;
+    }
+
+    if (!setpriority(PRIO_PROCESS, 0, p)) {
+        LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
+        return false;
+    }
+    return true;
+}
+
+#endif
+
+//
+// CLI argument parsing
+//
+
+
+void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) {
+    int32_t n_set = 0;
+
+    if (cpuparams.n_threads < 0) {
+        // Assuming everything about cpuparams is invalid
+        if (role_model != nullptr) {
+            cpuparams = *role_model;
+        } else {
+            cpuparams.n_threads = cpu_get_num_math();
+        }
+    }
+
+    for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+        if (cpuparams.cpumask[i]) {
+            n_set++;
+        }
+    }
+
+    if (n_set && n_set < cpuparams.n_threads) {
+        // Not enough set bits, may experience performance issues.
+        LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
+    }
+}
+
+bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
+    size_t dash_loc = range.find('-');
+    if (dash_loc == std::string::npos) {
+        LOG_ERR("Format of CPU range is invalid! Expected []-[].\n");
+        return false;
+    }
+
+    size_t start_i;
+    size_t end_i;
+
+    if (dash_loc == 0) {
+        start_i = 0;
+    } else {
+        start_i = std::stoull(range.substr(0, dash_loc));
+        if (start_i >= GGML_MAX_N_THREADS) {
+            LOG_ERR("Start index out of bounds!\n");
+            return false;
+        }
+    }
+
+    if (dash_loc == range.length() - 1) {
+        end_i = GGML_MAX_N_THREADS - 1;
+    } else {
+        end_i = std::stoull(range.substr(dash_loc + 1));
+        if (end_i >= GGML_MAX_N_THREADS) {
+            LOG_ERR("End index out of bounds!\n");
+            return false;
+        }
+    }
+
+    for (size_t i = start_i; i <= end_i; i++) {
+        boolmask[i] = true;
+    }
+
+    return true;
+}
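+// Usage sketch (illustrative, not upstream documentation): parse_cpu_range("0-3", mask)
+// marks mask[0]..mask[3] as allowed; an open-ended range such as "4-" marks
+// CPUs 4 through GGML_MAX_N_THREADS - 1.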
+
+bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREADS]) {
+    // Discard potential 0x prefix
+    size_t start_i = 0;
+    if (mask.length() >= 2 && mask.substr(0, 2) == "0x") {
+        start_i = 2;
+    }
+
+    size_t num_digits = mask.length() - start_i;
+    if (num_digits > 128) num_digits = 128;
+
+    size_t end_i = num_digits + start_i;
+
+    for (size_t i = start_i, n = (num_digits*4 - 1); i < end_i; i++, n-=4) {
+        char c = mask.at(i);
+        int8_t id = c;
+
+        if ((c >= '0' && c <= '9')) {
+            id -= '0';
+        } else if (c >= 'a' && c <= 'f') {
+            id -= 'a' - 10;
+        } else if (c >= 'A' && c <= 'F') {
+            id -= 'A' - 10;
+        } else {
+            LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
+            return false;
+        }
+
+        boolmask[  n  ] = boolmask[  n  ] || ((id & 8) != 0);
+        boolmask[n - 1] = boolmask[n - 1] || ((id & 4) != 0);
+        boolmask[n - 2] = boolmask[n - 2] || ((id & 2) != 0);
+        boolmask[n - 3] = boolmask[n - 3] || ((id & 1) != 0);
+    }
+
+    return true;
+}
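+// Usage sketch (illustrative): parse_cpu_mask("0x5", mask) sets mask[0] and mask[2];
+// parse_cpu_mask("ff", mask) sets mask[0] through mask[7]. Bits from repeated calls
+// are OR-ed into the existing mask rather than replacing it.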
+
+void common_init() {
+    llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
+        if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
+            common_log_add(common_log_main(), level, "%s", text);
+        }
+    }, NULL);
+
+#ifdef NDEBUG
+    const char * build_type = "";
+#else
+    const char * build_type = " (debug)";
+#endif
+
+    LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
+}
+
+std::string common_params_get_system_info(const common_params & params) {
+    std::ostringstream os;
+
+    os << "system_info: n_threads = " << params.cpuparams.n_threads;
+    if (params.cpuparams_batch.n_threads != -1) {
+        os << " (n_threads_batch = " << params.cpuparams_batch.n_threads << ")";
+    }
+#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
+    // TODO: windows + arm64 + mingw64
+    DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
+    os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
+#else
+    os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+#endif
+
+    return os.str();
+}
+
+//
+// String utils
+//
+
+std::string string_format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
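+// Usage sketch (illustrative): string_format("%s=%d", "n_ctx", 4096) returns the
+// std::string "n_ctx=4096"; the output length is measured first, so no truncation occurs.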
+
+std::string string_strip(const std::string & str) {
+    size_t start = 0;
+    size_t end = str.size();
+    while (start < end && std::isspace(str[start])) {
+        start++;
+    }
+    while (end > start && std::isspace(str[end - 1])) {
+        end--;
+    }
+    return str.substr(start, end - start);
+}
+
+std::string string_get_sortable_timestamp() {
+    using clock = std::chrono::system_clock;
+
+    const clock::time_point current_time = clock::now();
+    const time_t as_time_t = clock::to_time_t(current_time);
+    char timestamp_no_ns[100];
+    std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
+
+    const int64_t ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
+        current_time.time_since_epoch() % 1000000000).count();
+    char timestamp_ns[11];
+    snprintf(timestamp_ns, 11, "%09" PRId64, ns);
+
+    return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
+}
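+// Illustrative output: "2024_12_31-23_59_59.123456789" (local time; the nanosecond
+// suffix is always zero-padded to 9 digits).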
+
+void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    if (search.empty()) {
+        return;
+    }
+    std::string builder;
+    builder.reserve(s.length());
+    size_t pos = 0;
+    size_t last_pos = 0;
+    while ((pos = s.find(search, last_pos)) != std::string::npos) {
+        builder.append(s, last_pos, pos - last_pos);
+        builder.append(replace);
+        last_pos = pos + search.length();
+    }
+    builder.append(s, last_pos, std::string::npos);
+    s = std::move(builder);
+}
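+// Usage sketch (illustrative): given std::string s = "a-b-c", string_replace_all(s, "-", "::")
+// rewrites s in place to "a::b::c"; an empty search string leaves s untouched.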
+
+std::string string_from(bool value) {
+    return value ? "true" : "false";
+}
+
+std::string string_from(const std::vector<int> & values) {
+    std::stringstream buf;
+
+    buf << "[ ";
+    bool first = true;
+    for (auto e : values) {
+        if (first) {
+            first = false;
+        } else {
+            buf << ", ";
+        }
+        buf << std::to_string(e);
+    }
+    buf << " ]";
+
+    return buf.str();
+}
+
+std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
+    std::stringstream buf;
+
+    buf << "[ ";
+
+    bool first = true;
+    for (const auto & token : tokens) {
+        if (!first) {
+            buf << ", ";
+        } else {
+            first = false;
+        }
+
+        auto detokenized = common_token_to_piece(ctx, token);
+
+        detokenized.erase(
+            std::remove_if(
+                detokenized.begin(),
+                detokenized.end(),
+                [](const unsigned char c) { return !std::isprint(c); }),
+            detokenized.end());
+
+        buf << "'" << detokenized << "'"
+            << ":" << std::to_string(token);
+    }
+
+    buf << " ]";
+
+    return buf.str();
+}
+
+std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
+    std::stringstream buf;
+
+    buf << "[ ";
+
+    bool first = true;
+    for (int i = 0; i < batch.n_tokens; ++i) {
+        if (!first) {
+            buf << ", ";
+        } else {
+            first = false;
+        }
+
+        auto detokenized = common_token_to_piece(ctx, batch.token[i]);
+
+        detokenized.erase(
+                std::remove_if(
+                    detokenized.begin(),
+                    detokenized.end(),
+                    [](const unsigned char c) { return !std::isprint(c); }),
+                detokenized.end());
+
+        buf << "\n"          << std::to_string(i)
+            << ", token '"   << detokenized << "'"
+            << ", pos "      << std::to_string(batch.pos[i])
+            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ", seq_id "   << std::to_string(batch.seq_id[i][0])
+            << ", logits "   << std::to_string(batch.logits[i]);
+    }
+
+    buf << " ]";
+
+    return buf.str();
+}
+
+void string_process_escapes(std::string & input) {
+    std::size_t input_len = input.length();
+    std::size_t output_idx = 0;
+
+    for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
+        if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
+            switch (input[++input_idx]) {
+                case 'n':  input[output_idx++] = '\n'; break;
+                case 'r':  input[output_idx++] = '\r'; break;
+                case 't':  input[output_idx++] = '\t'; break;
+                case '\'': input[output_idx++] = '\''; break;
+                case '\"': input[output_idx++] = '\"'; break;
+                case '\\': input[output_idx++] = '\\'; break;
+                case 'x':
+                    // Handle \x12, etc
+                    if (input_idx + 2 < input_len) {
+                        const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 };
+                        char *err_p = nullptr;
+                        const long val = std::strtol(x, &err_p, 16);
+                        if (err_p == x + 2) {
+                            input_idx += 2;
+                            input[output_idx++] = char(val);
+                            break;
+                        }
+                    }
+                    // fall through
+                default:   input[output_idx++] = '\\';
+                           input[output_idx++] = input[input_idx]; break;
+            }
+        } else {
+            input[output_idx++] = input[input_idx];
+        }
+    }
+
+    input.resize(output_idx);
+}
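+// Usage sketch (illustrative): the two-character sequence "\n" in the input becomes a
+// real newline and "\x41" becomes 'A'; the string is edited in place and then shrunk.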
+
+bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+    const char * sep = strchr(data, '=');
+    if (sep == nullptr || sep - data >= 128) {
+        LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
+        return false;
+    }
+    llama_model_kv_override kvo;
+    std::strncpy(kvo.key, data, sep - data);
+    kvo.key[sep - data] = 0;
+    sep++;
+    if (strncmp(sep, "int:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+        kvo.val_i64 = std::atol(sep);
+    } else if (strncmp(sep, "float:", 6) == 0) {
+        sep += 6;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+        kvo.val_f64 = std::atof(sep);
+    } else if (strncmp(sep, "bool:", 5) == 0) {
+        sep += 5;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+        if (std::strcmp(sep, "true") == 0) {
+            kvo.val_bool = true;
+        } else if (std::strcmp(sep, "false") == 0) {
+            kvo.val_bool = false;
+        } else {
+            LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
+            return false;
+        }
+    } else if (strncmp(sep, "str:", 4) == 0) {
+        sep += 4;
+        kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+        if (strlen(sep) > 127) {
+            LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
+            return false;
+        }
+        strncpy(kvo.val_str, sep, 127);
+        kvo.val_str[127] = '\0';
+    } else {
+        LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
+        return false;
+    }
+    overrides.emplace_back(std::move(kvo));
+    return true;
+}
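+// Usage sketch (illustrative): string_parse_kv_override("tokenizer.ggml.add_bos_token=bool:false", overrides)
+// appends one entry with tag LLAMA_KV_OVERRIDE_TYPE_BOOL and val_bool == false.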
+
+//
+// Filesystem utils
+//
+
+// Validate if a filename is safe to use
+// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
+bool fs_validate_filename(const std::string & filename) {
+    if (!filename.length()) {
+        // Empty filename invalid
+        return false;
+    }
+    if (filename.length() > 255) {
+        // Limit at common largest possible filename on Linux filesystems
+        // to avoid unnecessary further validation
+        // (On systems with smaller limits it will be caught by the OS)
+        return false;
+    }
+
+    std::u32string filename_utf32;
+    try {
+#if defined(__clang__)
+        // disable C++17 deprecation warning for std::codecvt_utf8
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+        std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
+        filename_utf32 = converter.from_bytes(filename);
+
+        // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
+        // or invalid encodings were encountered. Reject such attempts
+        std::string filename_reencoded = converter.to_bytes(filename_utf32);
+        if (filename_reencoded != filename) {
+            return false;
+        }
+    } catch (const std::exception &) {
+        return false;
+    }
+
+    // Check for forbidden codepoints:
+    // - Control characters
+    // - Unicode equivalents of illegal characters
+    // - UTF-16 surrogate pairs
+    // - UTF-8 replacement character
+    // - Byte order mark (BOM)
+    // - Illegal characters: / \ : * ? " < > |
+    for (char32_t c : filename_utf32) {
+        if (c <= 0x1F // Control characters (C0)
+            || c == 0x7F // Control characters (DEL)
+            || (c >= 0x80 && c <= 0x9F) // Control characters (C1)
+            || c == 0xFF0E // Fullwidth Full Stop (period equivalent)
+            || c == 0x2215 // Division Slash (forward slash equivalent)
+            || c == 0x2216 // Set Minus (backslash equivalent)
+            || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
+            || c == 0xFFFD // Replacement Character (UTF-8)
+            || c == 0xFEFF // Byte Order Mark (BOM)
+            || c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
+            || c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
+            return false;
+        }
+    }
+
+    // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
+    // Unicode and other whitespace is not affected, only 0x20 space
+    if (filename.front() == ' ' || filename.back() == ' ' || filename.back() == '.') {
+        return false;
+    }
+
+    // Reject any ".." (currently stricter than necessary, it should be fine to just check for == ".." instead)
+    if (filename.find("..") != std::string::npos) {
+        return false;
+    }
+
+    // Reject "."
+    if (filename == ".") {
+        return false;
+    }
+
+    return true;
+}
+
+// returns true if successful, false otherwise
+bool fs_create_directory_with_parents(const std::string & path) {
+#ifdef _WIN32
+    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
+    std::wstring wpath = converter.from_bytes(path);
+
+    // if the path already exists, check whether it's a directory
+    const DWORD attributes = GetFileAttributesW(wpath.c_str());
+    if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+        return true;
+    }
+
+    size_t pos_slash = 0;
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
+        const std::wstring subpath = wpath.substr(0, pos_slash);
+        const wchar_t * test = subpath.c_str();
+
+        const bool success = CreateDirectoryW(test, NULL);
+        if (!success) {
+            const DWORD error = GetLastError();
+
+            // if the path already exists, ensure that it's a directory
+            if (error == ERROR_ALREADY_EXISTS) {
+                const DWORD attributes = GetFileAttributesW(subpath.c_str());
+                if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+                    return false;
+                }
+            } else {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#else
+    // if the path already exists, check whether it's a directory
+    struct stat info;
+    if (stat(path.c_str(), &info) == 0) {
+        return S_ISDIR(info.st_mode);
+    }
+
+    size_t pos_slash = 1; // skip leading slashes for directory creation
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
+        const std::string subpath = path.substr(0, pos_slash);
+        struct stat info;
+
+        // if the path already exists, ensure that it's a directory
+        if (stat(subpath.c_str(), &info) == 0) {
+            if (!S_ISDIR(info.st_mode)) {
+                return false;
+            }
+        } else {
+            // create parent directories
+            const int ret = mkdir(subpath.c_str(), 0755);
+            if (ret != 0) {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#endif // _WIN32
+}
+
+std::string fs_get_cache_directory() {
+    std::string cache_directory = "";
+    auto ensure_trailing_slash = [](std::string p) {
+        // Make sure to add trailing slash
+        if (p.back() != DIRECTORY_SEPARATOR) {
+            p += DIRECTORY_SEPARATOR;
+        }
+        return p;
+    };
+    if (getenv("LLAMA_CACHE")) {
+        cache_directory = std::getenv("LLAMA_CACHE");
+    } else {
+#ifdef __linux__
+        if (std::getenv("XDG_CACHE_HOME")) {
+            cache_directory = std::getenv("XDG_CACHE_HOME");
+        } else {
+            cache_directory = std::getenv("HOME") + std::string("/.cache/");
+        }
+#elif defined(__APPLE__)
+        cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+#elif defined(_WIN32)
+        cache_directory = std::getenv("LOCALAPPDATA");
+#endif // __linux__
+        cache_directory = ensure_trailing_slash(cache_directory);
+        cache_directory += "llama.cpp";
+    }
+    return ensure_trailing_slash(cache_directory);
+}
+
+std::string fs_get_cache_file(const std::string & filename) {
+    GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos);
+    std::string cache_directory = fs_get_cache_directory();
+    const bool success = fs_create_directory_with_parents(cache_directory);
+    if (!success) {
+        throw std::runtime_error("failed to create cache directory: " + cache_directory);
+    }
+    return cache_directory + filename;
+}
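+// Usage sketch (illustrative): with LLAMA_CACHE unset on Linux, fs_get_cache_file("model.gguf")
+// typically resolves to "$XDG_CACHE_HOME/llama.cpp/model.gguf" or "$HOME/.cache/llama.cpp/model.gguf",
+// creating the cache directory first if it does not exist.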
+
+
+//
+// Model utils
+//
+struct common_init_result common_init_from_params(common_params & params) {
+    common_init_result iparams;
+    auto mparams = common_model_params_to_llama(params);
+
+    llama_model * model = nullptr;
+
+    if (!params.hf_repo.empty() && !params.hf_file.empty()) {
+        model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
+    } else if (!params.model_url.empty()) {
+        model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
+    } else {
+        model = llama_load_model_from_file(params.model.c_str(), mparams);
+    }
+
+    if (model == NULL) {
+        LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
+        return iparams;
+    }
+
+    if (params.reranking) {
+        bool ok = true;
+
+        if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: model does not have a  BOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: model does not have a  SEP token, reranking will not work\n", __func__);
+            ok = false;
+        }
+
+        if (!ok) {
+            llama_free_model(model);
+
+            return iparams;
+        }
+    }
+
+    auto cparams = common_context_params_to_llama(params);
+
+    llama_context * lctx = llama_new_context_with_model(model, cparams);
+    if (lctx == NULL) {
+        LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
+        llama_free_model(model);
+        return iparams;
+    }
+
+    if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
+        LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
+        params.ctx_shift = false;
+    }
+
+    if (!params.control_vectors.empty()) {
+        if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
+        if (params.control_vector_layer_end   <= 0) params.control_vector_layer_end   = llama_n_layer(model);
+
+        const auto cvec = common_control_vector_load(params.control_vectors);
+        if (cvec.n_embd == -1) {
+            llama_free(lctx);
+            llama_free_model(model);
+
+            return iparams;
+        }
+
+        int err = llama_control_vector_apply(lctx,
+                                             cvec.data.data(),
+                                             cvec.data.size(),
+                                             cvec.n_embd,
+                                             params.control_vector_layer_start,
+                                             params.control_vector_layer_end);
+        if (err) {
+            llama_free(lctx);
+            llama_free_model(model);
+
+            return iparams;
+        }
+    }
+
+    // load and optionally apply lora adapters
+    for (auto & la : params.lora_adapters) {
+        llama_lora_adapter_ptr lora;
+        lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
+        if (lora == nullptr) {
+            LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+            llama_free(lctx);
+            llama_free_model(model);
+            return iparams;
+        }
+
+        la.ptr = lora.get();
+        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+    }
+
+    if (!params.lora_init_without_apply) {
+        common_lora_adapters_apply(lctx, params.lora_adapters);
+    }
+
+    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+        params.sampling.ignore_eos = false;
+    }
+
+    if (params.sampling.ignore_eos) {
+        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
+            if (llama_token_is_eog(model, i)) {
+                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                params.sampling.logit_bias.push_back({i, -INFINITY});
+            }
+        }
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.warmup) {
+        LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
+
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != LLAMA_TOKEN_NULL) {
+            tmp.push_back(bos);
+        }
+        if (eos != LLAMA_TOKEN_NULL) {
+            tmp.push_back(eos);
+        }
+        if (tmp.empty()) {
+            tmp.push_back(0);
+        }
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
+        if (llama_model_has_decoder(model)) {
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
+        }
+        llama_kv_cache_clear(lctx);
+        llama_synchronize(lctx);
+        llama_perf_context_reset(lctx);
+    }
+
+    iparams.model.reset(model);
+    iparams.context.reset(lctx);
+
+    return iparams;
+}
+
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
+    llama_lora_adapter_clear(ctx);
+    for (auto & la : lora) {
+        if (la.scale != 0.0f) {
+            llama_lora_adapter_set(ctx, la.ptr, la.scale);
+        }
+    }
+}
+
+struct llama_model_params common_model_params_to_llama(common_params & params) {
+    auto mparams = llama_model_default_params();
+
+    if (!params.devices.empty()) {
+        mparams.devices = params.devices.data();
+    }
+    if (params.n_gpu_layers != -1) {
+        mparams.n_gpu_layers = params.n_gpu_layers;
+    }
+    mparams.rpc_servers     = params.rpc_servers.c_str();
+    mparams.main_gpu        = params.main_gpu;
+    mparams.split_mode      = params.split_mode;
+    mparams.tensor_split    = params.tensor_split;
+    mparams.use_mmap        = params.use_mmap;
+    mparams.use_mlock       = params.use_mlock;
+    mparams.check_tensors   = params.check_tensors;
+    if (params.kv_overrides.empty()) {
+        mparams.kv_overrides = NULL;
+    } else {
+        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+        mparams.kv_overrides = params.kv_overrides.data();
+    }
+
+    return mparams;
+}
+
+struct llama_context_params common_context_params_to_llama(const common_params & params) {
+    auto cparams = llama_context_default_params();
+
+    cparams.n_ctx             = params.n_ctx;
+    cparams.n_seq_max         = params.n_parallel;
+    cparams.n_batch           = params.n_batch;
+    cparams.n_ubatch          = params.n_ubatch;
+    cparams.n_threads         = params.cpuparams.n_threads;
+    cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
+                                params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+    cparams.logits_all        = params.logits_all;
+    cparams.embeddings        = params.embedding;
+    cparams.rope_scaling_type = params.rope_scaling_type;
+    cparams.rope_freq_base    = params.rope_freq_base;
+    cparams.rope_freq_scale   = params.rope_freq_scale;
+    cparams.yarn_ext_factor   = params.yarn_ext_factor;
+    cparams.yarn_attn_factor  = params.yarn_attn_factor;
+    cparams.yarn_beta_fast    = params.yarn_beta_fast;
+    cparams.yarn_beta_slow    = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
+    cparams.pooling_type      = params.pooling_type;
+    cparams.attention_type    = params.attention_type;
+    cparams.defrag_thold      = params.defrag_thold;
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.offload_kqv       = !params.no_kv_offload;
+    cparams.flash_attn        = params.flash_attn;
+    cparams.no_perf           = params.no_perf;
+
+    if (params.reranking) {
+        cparams.embeddings    = true;
+        cparams.pooling_type  = LLAMA_POOLING_TYPE_RANK;
+    }
+
+    cparams.type_k = params.cache_type_k;
+    cparams.type_v = params.cache_type_v;
+
+    return cparams;
+}
+
+struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params) {
+    struct ggml_threadpool_params tpp;
+
+    ggml_threadpool_params_init(&tpp, params.n_threads); // setup the defaults
+
+    if (params.mask_valid) {
+        std::memcpy(&tpp.cpumask, &params.cpumask, GGML_MAX_N_THREADS);
+    }
+
+    tpp.prio       = params.priority;
+    tpp.poll       = params.poll;
+    tpp.strict_cpu = params.strict_cpu;
+
+    return tpp;
+}
+
+#ifdef LLAMA_USE_CURL
+
+#define CURL_MAX_RETRY 3
+#define CURL_RETRY_DELAY_SECONDS 2
+
+static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
+    int remaining_attempts = max_attempts;
+
+    while (remaining_attempts > 0) {
+        LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
+
+        CURLcode res = curl_easy_perform(curl);
+        if (res == CURLE_OK) {
+            return true;
+        }
+
+        int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
+        LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
+
+        remaining_attempts--;
+        std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
+    }
+
+    LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
+
+    return false;
+}
+
+static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
+    // Initialize libcurl
+    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
+    if (!curl) {
+        LOG_ERR("%s: error initializing libcurl\n", __func__);
+        return false;
+    }
+
+    bool force_download = false;
+
+    // Set the URL, allow to follow http redirection
+    curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
+    curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
+
+    // Check if hf-token or bearer-token was specified
+    if (!hf_token.empty()) {
+      std::string auth_header = "Authorization: Bearer ";
+      auth_header += hf_token.c_str();
+      struct curl_slist *http_headers = NULL;
+      http_headers = curl_slist_append(http_headers, auth_header.c_str());
+      curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers);
+    }
+
+#if defined(_WIN32)
+    // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
+    //   operating system. Currently implemented under MS-Windows.
+    curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
+#endif
+
+    // Check if the file already exists locally
+    auto file_exists = std::filesystem::exists(path);
+
+    // If the file exists, check its JSON metadata companion file.
+    std::string metadata_path = path + ".json";
+    nlohmann::json metadata;
+    std::string etag;
+    std::string last_modified;
+
+    if (file_exists) {
+        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
+        std::ifstream metadata_in(metadata_path);
+        if (metadata_in.good()) {
+            try {
+                metadata_in >> metadata;
+                LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
+                if (metadata.contains("url") && metadata.at("url").is_string()) {
+                    auto previous_url = metadata.at("url").get();
+                    if (previous_url != url) {
+                        LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
+                        return false;
+                    }
+                }
+                if (metadata.contains("etag") && metadata.at("etag").is_string()) {
+                    etag = metadata.at("etag");
+                }
+                if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
+                    last_modified = metadata.at("lastModified");
+                }
+            } catch (const nlohmann::json::exception & e) {
+            LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
+                return false;
+            }
+        }
+    } else {
+        LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
+    }
+
+    // Send a HEAD request to retrieve the etag and last-modified headers
+    struct common_load_model_from_url_headers {
+        std::string etag;
+        std::string last_modified;
+    };
+
+    common_load_model_from_url_headers headers;
+
+    {
+        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
+        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
+            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
+
+            static std::regex header_regex("([^:]+): (.*)\r\n");
+            static std::regex etag_regex("ETag", std::regex_constants::icase);
+            static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
+
+            std::string header(buffer, n_items);
+            std::smatch match;
+            if (std::regex_match(header, match, header_regex)) {
+                const std::string & key = match[1];
+                const std::string & value = match[2];
+                if (std::regex_match(key, match, etag_regex)) {
+                    headers->etag = value;
+                } else if (std::regex_match(key, match, last_modified_regex)) {
+                    headers->last_modified = value;
+                }
+            }
+            return n_items;
+        };
+
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
+
+        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+        if (!was_perform_successful) {
+            return false;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code != 200) {
+            // HEAD not supported, we don't know if the file has changed
+            // force trigger downloading
+            force_download = true;
+            LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
+        }
+    }
+
+    bool should_download = !file_exists || force_download;
+    if (!should_download) {
+        if (!etag.empty() && etag != headers.etag) {
+            LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
+            should_download = true;
+        } else if (!last_modified.empty() && last_modified != headers.last_modified) {
+            LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
+            should_download = true;
+        }
+    }
+    if (should_download) {
+        std::string path_temporary = path + ".downloadInProgress";
+        if (file_exists) {
+            LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
+            if (remove(path.c_str()) != 0) {
+                LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
+                return false;
+            }
+        }
+
+        // Set the output file
+
+        struct FILE_deleter {
+            void operator()(FILE * f) const {
+                fclose(f);
+            }
+        };
+
+        std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
+        if (!outfile) {
+            LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
+            return false;
+        }
+
+        typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
+        auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
+            return fwrite(data, size, nmemb, (FILE *)fd);
+        };
+        curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
+        curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
+
+        //  display download progress
+        curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
+
+        // helper function to hide password in URL
+        auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
+            std::size_t protocol_pos = url.find("://");
+            if (protocol_pos == std::string::npos) {
+                return url;  // Malformed URL
+            }
+
+            std::size_t at_pos = url.find('@', protocol_pos + 3);
+            if (at_pos == std::string::npos) {
+                return url;  // No password in URL
+            }
+
+            return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
+        };
+
+        // start the download
+        LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
+            llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
+        bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
+        if (!was_perform_successful) {
+            return false;
+        }
+
+        long http_code = 0;
+        curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
+        if (http_code < 200 || http_code >= 400) {
+            LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
+            return false;
+        }
+
+        // Causes file to be closed explicitly here before we rename it.
+        outfile.reset();
+
+        // Write the updated JSON metadata file.
+        metadata.update({
+            {"url", url},
+            {"etag", headers.etag},
+            {"lastModified", headers.last_modified}
+        });
+        std::ofstream(metadata_path) << metadata.dump(4);
+        LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
+
+        if (rename(path_temporary.c_str(), path.c_str()) != 0) {
+            LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
+            return false;
+        }
+    }
+
+    return true;
+}
+
+struct llama_model * common_load_model_from_url(
+        const std::string & model_url,
+        const std::string & local_path,
+        const std::string & hf_token,
+        const struct llama_model_params & params) {
+    // Basic validation of the model_url
+    if (model_url.empty()) {
+        LOG_ERR("%s: invalid model_url\n", __func__);
+        return NULL;
+    }
+
+    if (!common_download_file(model_url, local_path, hf_token)) {
+        return NULL;
+    }
+
+    // check for additional GGUFs split to download
+    int n_split = 0;
+    {
+        struct gguf_init_params gguf_params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ NULL,
+        };
+        auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
+        if (!ctx_gguf) {
+            LOG_ERR("\n%s:  failed to load input GGUF from %s\n", __func__, local_path.c_str());
+            return NULL;
+        }
+
+        auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
+        if (key_n_split >= 0) {
+            n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
+        }
+
+        gguf_free(ctx_gguf);
+    }
+
+    if (n_split > 1) {
+        char split_prefix[PATH_MAX] = {0};
+        char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
+
+        // Verify the first split file format
+        // and extract split URL and PATH prefixes
+        {
+            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
+                return NULL;
+            }
+
+            if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
+                LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
+                return NULL;
+            }
+        }
+
+        // Prepare download in parallel
+        std::vector<std::future<bool>> futures_download;
+        for (int idx = 1; idx < n_split; idx++) {
+            futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
+                char split_path[PATH_MAX] = {0};
+                llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
+
+                char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
+                llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
+
+                return common_download_file(split_url, split_path, hf_token);
+            }, idx));
+        }
+
+        // Wait for all downloads to complete
+        for (auto & f : futures_download) {
+            if (!f.get()) {
+                return NULL;
+            }
+        }
+    }
+
+    return llama_load_model_from_file(local_path.c_str(), params);
+}
+
+struct llama_model * common_load_model_from_hf(
+        const std::string & repo,
+        const std::string & remote_path,
+        const std::string & local_path,
+        const std::string & hf_token,
+        const struct llama_model_params & params) {
+    // construct hugging face model url:
+    //
+    //  --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
+    //    https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
+    //
+    //  --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //    https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
+    //
+
+    std::string model_url = "https://huggingface.co/";
+    model_url += repo;
+    model_url += "/resolve/main/";
+    model_url += remote_path;
+
+    return common_load_model_from_url(model_url, local_path, hf_token, params);
+}
+
+#else
+
+struct llama_model * common_load_model_from_url(
+        const std::string & /*model_url*/,
+        const std::string & /*local_path*/,
+        const std::string & /*hf_token*/,
+        const struct llama_model_params & /*params*/) {
+    LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
+    return nullptr;
+}
+
+struct llama_model * common_load_model_from_hf(
+        const std::string & /*repo*/,
+        const std::string & /*remote_path*/,
+        const std::string & /*local_path*/,
+        const std::string & /*hf_token*/,
+        const struct llama_model_params & /*params*/) {
+    LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
+    return nullptr;
+}
+
+#endif // LLAMA_USE_CURL
+
+//
+// Batch utils
+//
+
+void common_batch_clear(struct llama_batch & batch) {
+    batch.n_tokens = 0;
+}
+
+void common_batch_add(
+                 struct llama_batch & batch,
+                        llama_token   id,
+                          llama_pos   pos,
+    const std::vector<llama_seq_id> & seq_ids,
+                               bool   logits) {
+    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
+
+    batch.token   [batch.n_tokens] = id;
+    batch.pos     [batch.n_tokens] = pos;
+    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); ++i) {
+        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    }
+    batch.logits  [batch.n_tokens] = logits;
+
+    batch.n_tokens++;
+}
+
+//
+// Token utils
+//
+
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
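+// Usage sketch (illustrative): common_lcp({1, 2, 3, 4}, {1, 2, 9}) == 2, the
+// length of the shared prefix {1, 2}.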
+
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    size_t a_len = a.size();
+    size_t b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    size_t max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<size_t> prev_row(b_len + 1, 0);
+    std::vector<size_t> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (size_t i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (size_t j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequences, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
+}
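+// Usage sketch (illustrative): common_lcs({1, 2, 3, 4}, {9, 2, 3, 7}) == 2, the
+// length of the longest contiguous run shared by both sequences ({2, 3}).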
+
+//
+// Vocab utils
+//
+
+std::vector<llama_token> common_tokenize(
+  const struct llama_context * ctx,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special) {
+    return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+}
+
+std::vector<llama_token> common_tokenize(
+    const struct llama_model * model,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special) {
+    // upper limit for the number of tokens
+    int n_tokens = text.length() + 2 * add_special;
+    std::vector<llama_token> result(n_tokens);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    if (n_tokens < 0) {
+        result.resize(-n_tokens);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        GGML_ASSERT(check == -n_tokens);
+    } else {
+        result.resize(n_tokens);
+    }
+    return result;
+}
+
+std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
+    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
+    }
+
+    return piece;
+}
+
+std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
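+
+// Usage sketch (illustrative, not part of the vendored sources; assumes an already
+// initialized llama_context). Tokenization and detokenization share the two-pass
+// convention used above: a negative return value reports the required buffer size,
+// so the wrappers resize once and retry. The helper name below is hypothetical.
+static std::string common_tokenize_roundtrip_example(struct llama_context * ctx, const std::string & text) {
+    const std::vector<llama_token> toks = common_tokenize(ctx, text, /*add_special=*/true);
+    return common_detokenize(ctx, toks, /*special=*/true);
+}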
+
+//
+// Chat template utils
+//
+
+std::string common_get_builtin_chat_template(const struct llama_model * model) {
+    static const char * template_key = "tokenizer.chat_template";
+    // call with NULL buffer to get the total size of the string
+    int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
+    if (res > 0) {
+        std::vector<char> model_template(res + 1, 0);
+        llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
+        return std::string(model_template.data(), model_template.size() - 1);
+    }
+    return "";
+}
+
+bool common_chat_verify_template(const std::string & tmpl) {
+    llama_chat_message chat[] = {{"user", "test"}};
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
+std::string common_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<common_chat_msg> & msgs,
+        bool add_ass) {
+    int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
+    std::vector<llama_chat_message> chat;
+    for (auto & msg : msgs) {
+        chat.push_back({msg.role.c_str(), msg.content.c_str()});
+        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
+    }
+
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size);
+
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+
+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this check is intentionally redundant, since we cannot be sure the user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    }
+
+    std::string formatted_chat(buf.data(), res);
+    return formatted_chat;
+}
+
+std::string common_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass) {
+    std::ostringstream ss;
+    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false);
+    std::vector<common_chat_msg> chat_new(past_msg);
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
+    chat_new.push_back(new_msg);
+    auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass);
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
+}
+
+std::string common_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl) {
+    std::vector<common_chat_msg> msgs = {
+        {"system",    "You are a helpful assistant"},
+        {"user",      "Hello"},
+        {"assistant", "Hi there"},
+        {"user",      "How are you?"},
+    };
+    return common_chat_apply_template(model, tmpl, msgs, true);
+}
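+
+// Usage sketch (illustrative, not part of the vendored sources; assumes a loaded
+// llama_model). Passing an empty template string selects the model's built-in chat
+// template, and common_chat_apply_template above falls back to chatml if that
+// built-in template is unsupported. The helper name below is hypothetical.
+static std::string common_chat_example(const struct llama_model * model) {
+    std::vector<common_chat_msg> msgs = {
+        {"system", "You are a helpful assistant"},
+        {"user",   "Hello"},
+    };
+    return common_chat_apply_template(model, /*tmpl=*/"", msgs, /*add_ass=*/true);
+}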
+
+//
+// KV cache utils
+//
+
+void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+        view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_seq_max; j++) {
+            if (cs_curr[j] >= 0) { seq_count++; }
+        }
+        putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+        view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    std::unordered_map<llama_seq_id, size_t> seqs;
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        for (int j = 0; j < view.n_seq_max; j++) {
+            if (cs_curr[j] < 0) { continue; }
+            if (seqs.find(cs_curr[j]) == seqs.end()) {
+                if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+                const size_t sz = seqs.size();
+                seqs[cs_curr[j]] = sz;
+            }
+        }
+        if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+    }
+
+    printf("=== Sequence legend: ");
+    for (const auto & it : seqs) {
+        printf("%zu=%d, ", it.second, it.first);
+    }
+    printf("'+'=other sequence ids");
+
+    c_curr = view.cells;
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_seq_max; j++) {
+            if (cs_curr[j] >= 0) {
+                const auto & it = seqs.find(cs_curr[j]);
+                putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+            } else {
+                putchar('.');
+            }
+        }
+        putchar(' ');
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+//
+// Embedding utils
+//
+
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
+    double sum = 0.0;
+
+    switch (embd_norm) {
+        case -1: // no normalisation
+            sum = 1.0;
+            break;
+        case 0: // max absolute
+            for (int i = 0; i < n; i++) {
+                if (sum < std::abs(inp[i])) {
+                    sum = std::abs(inp[i]);
+                }
+            }
+            sum /= 32760.0; // make an int16 range
+            break;
+        case 2: // euclidean
+            for (int i = 0; i < n; i++) {
+                sum += inp[i] * inp[i];
+            }
+            sum = std::sqrt(sum);
+            break;
+        default: // p-norm (euclidean is p-norm p=2)
+            for (int i = 0; i < n; i++) {
+                sum += std::pow(std::abs(inp[i]), embd_norm);
+            }
+            sum = std::pow(sum, 1.0 / embd_norm);
+            break;
+    }
+
+    const float norm = sum > 0.0 ? 1.0 / sum : 0.0f;
+
+    for (int i = 0; i < n; i++) {
+        out[i] = inp[i] * norm;
+    }
+}
+
+float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){
+    double sum  = 0.0;
+    double sum1 = 0.0;
+    double sum2 = 0.0;
+
+    for (int i = 0; i < n; i++) {
+        sum  += embd1[i] * embd2[i];
+        sum1 += embd1[i] * embd1[i];
+        sum2 += embd2[i] * embd2[i];
+    }
+
+    // Handle the case where one or both vectors are zero vectors
+    if (sum1 == 0.0 || sum2 == 0.0) {
+        if (sum1 == 0.0 && sum2 == 0.0) {
+            return 1.0f; // two zero vectors are similar
+        }
+        return 0.0f;
+    }
+
+    return sum / (sqrt(sum1) * sqrt(sum2));
+}
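+
+// Worked example (illustrative, not part of the vendored sources): with euclidean
+// normalization (embd_norm == 2) both vectors end up with unit length, so the cosine
+// similarity above reduces to a plain dot product. For a = {1,0,2,2} (norm 3) and
+// b = {0,3,0,4} (norm 5), the similarity is 8 / (3*5) ~= 0.533 before and after
+// normalization. The helper name below is hypothetical.
+static float common_embd_example() {
+    const float a[4] = {1.0f, 0.0f, 2.0f, 2.0f};
+    const float b[4] = {0.0f, 3.0f, 0.0f, 4.0f};
+    float an[4];
+    float bn[4];
+    common_embd_normalize(a, an, 4, /*embd_norm=*/2);
+    common_embd_normalize(b, bn, 4, /*embd_norm=*/2);
+    return common_embd_similarity_cos(an, bn, 4);
+}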
+
+//
+// Control vector utils
+//
+
+static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) {
+    common_control_vector_data result = { -1, {} };
+
+    ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ false,
+        /* .ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
+    if (!ctx_gguf) {
+        LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
+        return result;
+    }
+
+    int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
+    if (n_tensors == 0) {
+        LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
+    }
+
+    for (int i = 0; i < n_tensors; i++) {
+        std::string name = gguf_get_tensor_name(ctx_gguf, i);
+
+        int layer_idx = -1;
+
+        // split on '.'
+        size_t dotpos = name.find('.');
+        if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") {
+            try {
+                layer_idx = std::stoi(name.substr(dotpos + 1));
+            } catch (...) {
+                layer_idx = -1;
+            }
+        }
+        if (layer_idx < 0) {
+            LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        } else if (layer_idx == 0) {
+            LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
+        if (tensor->type != GGML_TYPE_F32) {
+            LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+        if (ggml_n_dims(tensor) != 1) {
+            LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        if (result.n_embd == -1) {
+            result.n_embd = ggml_nelements(tensor);
+        } else if (ggml_nelements(tensor) != result.n_embd) {
+            LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        // extend if necessary - do not store data for layer 0 (it's not used)
+        result.data.resize(std::max(result.data.size(), static_cast<size_t>(result.n_embd * layer_idx)), 0.0f);
+
+        const float * src = (const float *) tensor->data;
+        float * dst = result.data.data() + result.n_embd * (layer_idx - 1);  // layer 1 at [0]
+        for (int j = 0; j < result.n_embd; j++) {
+            dst[j] += src[j] * load_info.strength;  // allows multiple directions for same layer in same file
+        }
+
+    }
+
+    if (result.n_embd == -1) {
+        LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
+        result.data.clear();
+    }
+
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+
+    return result;
+}
+
+common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos) {
+    common_control_vector_data result = { -1, {} };
+
+    for (const auto & info : load_infos) {
+        auto cur = common_control_vector_load_one(info);
+
+        if (cur.n_embd == -1) {
+            result.n_embd = -1;
+            break;
+        }
+        if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
+            LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
+            result.n_embd = -1;
+            break;
+        }
+
+        if (result.n_embd == -1) {
+            result = std::move(cur);
+        } else {
+            result.data.resize(std::max(result.data.size(), cur.data.size()), 0.0f);  // extend if necessary
+            for (size_t i = 0; i < cur.data.size(); i++) {
+                result.data[i] += cur.data[i];
+            }
+        }
+    }
+
+    if (result.n_embd == -1) {
+        LOG_ERR("%s: no valid control vector files passed\n", __func__);
+        result.data.clear();
+    }
+
+    return result;
+}
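+
+// Usage sketch (illustrative, not part of the vendored sources; the file names and the
+// helper name are placeholders). Each file is loaded, scaled by its strength, and the
+// per-layer directions are summed by common_control_vector_load above; n_embd == -1
+// in the result signals that no valid control vector could be loaded.
+static common_control_vector_data common_control_vector_example() {
+    std::vector<common_control_vector_load_info> infos = {
+        { /*strength=*/  1.0f, /*fname=*/ "happy.gguf"  },
+        { /*strength=*/ -0.5f, /*fname=*/ "formal.gguf" },
+    };
+    return common_control_vector_load(infos);
+}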
+
diff --git a/llama/common.h b/llama/common.h
new file mode 100644
index 000000000..db931490c
--- /dev/null
+++ b/llama/common.h
@@ -0,0 +1,675 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// Various helper functions and utilities
+
+#pragma once
+
+#include "llama-cpp.h"
+
+#include <string>
+#include <vector>
+#include <sstream>
+
+#ifdef _WIN32
+#define DIRECTORY_SEPARATOR '\\'
+#else
+#define DIRECTORY_SEPARATOR '/'
+#endif // _WIN32
+
+#define die(msg)          do { fputs("error: " msg "\n", stderr);                exit(1); } while (0)
+#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
+
+#define print_build_info() do {                                                                     \
+    fprintf(stderr, "%s: build = %d (%s)\n",      __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);      \
+    fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);    \
+} while(0)
+
+#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
+
+struct common_lora_adapter_info {
+    std::string path;
+    float scale;
+
+    struct llama_lora_adapter * ptr;
+};
+
+using llama_tokens = std::vector<llama_token>;
+
+// build info
+extern int LLAMA_BUILD_NUMBER;
+extern const char * LLAMA_COMMIT;
+extern const char * LLAMA_COMPILER;
+extern const char * LLAMA_BUILD_TARGET;
+
+struct common_control_vector_load_info;
+
+//
+// CPU utils
+//
+
+struct cpu_params {
+    int      n_threads                   = -1;
+    bool     cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
+    bool     mask_valid                  = false;   // Default: any CPU
+    enum ggml_sched_priority  priority   = GGML_SCHED_PRIO_NORMAL;  // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
+    bool     strict_cpu                  = false;   // Use strict CPU placement
+    uint32_t poll                        = 50;      // Polling (busywait) level (0 - no polling, 100 - mostly polling)
+};
+
+int32_t cpu_get_num_physical_cores();
+int32_t cpu_get_num_math();
+
+//
+// Common params
+//
+
+enum llama_example {
+    LLAMA_EXAMPLE_COMMON,
+    LLAMA_EXAMPLE_SPECULATIVE,
+    LLAMA_EXAMPLE_MAIN,
+    LLAMA_EXAMPLE_INFILL,
+    LLAMA_EXAMPLE_EMBEDDING,
+    LLAMA_EXAMPLE_PERPLEXITY,
+    LLAMA_EXAMPLE_RETRIEVAL,
+    LLAMA_EXAMPLE_PASSKEY,
+    LLAMA_EXAMPLE_IMATRIX,
+    LLAMA_EXAMPLE_BENCH,
+    LLAMA_EXAMPLE_SERVER,
+    LLAMA_EXAMPLE_CVECTOR_GENERATOR,
+    LLAMA_EXAMPLE_EXPORT_LORA,
+    LLAMA_EXAMPLE_LLAVA,
+    LLAMA_EXAMPLE_LOOKUP,
+    LLAMA_EXAMPLE_PARALLEL,
+    LLAMA_EXAMPLE_TTS,
+
+    LLAMA_EXAMPLE_COUNT,
+};
+
+enum common_sampler_type {
+    COMMON_SAMPLER_TYPE_NONE        = 0,
+    COMMON_SAMPLER_TYPE_DRY         = 1,
+    COMMON_SAMPLER_TYPE_TOP_K       = 2,
+    COMMON_SAMPLER_TYPE_TOP_P       = 3,
+    COMMON_SAMPLER_TYPE_MIN_P       = 4,
+  //COMMON_SAMPLER_TYPE_TFS_Z       = 5,
+    COMMON_SAMPLER_TYPE_TYPICAL_P   = 6,
+    COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
+    COMMON_SAMPLER_TYPE_XTC         = 8,
+    COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
+};
+
+// dimensionality reduction methods, used by cvector-generator
+enum dimre_method {
+    DIMRE_METHOD_PCA,
+    DIMRE_METHOD_MEAN,
+};
+
+// sampling parameters
+struct common_params_sampling {
+    uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
+
+    int32_t n_prev             = 64;    // number of previous tokens to remember
+    int32_t n_probs            = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t min_keep           = 0;     // 0 = disabled, otherwise samplers should return at least min_keep tokens
+    int32_t top_k              = 40;    // <= 0 to use vocab size
+    float   top_p              = 0.95f; // 1.0 = disabled
+    float   min_p              = 0.05f; // 0.0 = disabled
+    float   xtc_probability    = 0.00f; // 0.0 = disabled
+    float   xtc_threshold      = 0.10f; // > 0.5 disables XTC
+    float   typ_p              = 1.00f; // typical_p, 1.0 = disabled
+    float   temp               = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
+    float   dynatemp_range     = 0.00f; // 0.0 = disabled
+    float   dynatemp_exponent  = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    int32_t penalty_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat     = 1.00f; // 1.0 = disabled
+    float   penalty_freq       = 0.00f; // 0.0 = disabled
+    float   penalty_present    = 0.00f; // 0.0 = disabled
+    float   dry_multiplier     = 0.0f;  // 0.0 = disabled;      DRY repetition penalty for tokens extending repetition:
+    float   dry_base           = 1.75f; // 0.0 = disabled;      multiplier * base ^ (length of sequence before token - allowed length)
+    int32_t dry_allowed_length = 2;     // tokens extending repetitions beyond this receive penalty
+    int32_t dry_penalty_last_n = -1;    // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
+    int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+    float   mirostat_tau       = 5.00f; // target entropy
+    float   mirostat_eta       = 0.10f; // learning rate
+    bool    ignore_eos         = false;
+    bool    no_perf            = false; // disable performance metrics
+    bool    timing_per_token   = false;
+
+    std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"};     // default sequence breakers for DRY
+
+
+    std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_PENALTIES,
+        COMMON_SAMPLER_TYPE_DRY,
+        COMMON_SAMPLER_TYPE_TOP_K,
+        COMMON_SAMPLER_TYPE_TYPICAL_P,
+        COMMON_SAMPLER_TYPE_TOP_P,
+        COMMON_SAMPLER_TYPE_MIN_P,
+        COMMON_SAMPLER_TYPE_XTC,
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
+    };
+
+    std::string grammar; // optional BNF-like grammar to constrain sampling
+
+    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+
+    // print the parameters into a string
+    std::string print() const;
+};
+
+struct common_params_speculative {
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_ctx        =     0; // draft context size
+    int32_t n_max        =    16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min        =     5; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split      =  0.1f; // speculative decoding split probability
+    float   p_min        =  0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding                          // NOLINT
+};
+
+struct common_params_vocoder {
+    std::string hf_repo = ""; // HF repo                                                     // NOLINT
+    std::string hf_file = ""; // HF file                                                     // NOLINT
+
+    std::string model     = ""; // model path                                                // NOLINT
+    std::string model_url = ""; // model url to download                                     // NOLINT
+};
+
+struct common_params {
+    int32_t n_predict             =    -1; // new tokens to predict
+    int32_t n_ctx                 =  4096; // context size
+    int32_t n_batch               =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_ubatch              =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep                =     0; // number of tokens to keep from initial prompt
+    int32_t n_chunks              =    -1; // max number of chunks to process (-1 = unlimited)
+    int32_t n_parallel            =     1; // number of parallel sequences to decode
+    int32_t n_sequences           =     1; // number of sequences to decode
+    int32_t grp_attn_n            =     1; // group-attention factor
+    int32_t grp_attn_w            =   512; // group-attention width
+    int32_t n_print               =    -1; // print token count every n tokens (-1 = disabled)
+    float   rope_freq_base        =  0.0f; // RoPE base frequency
+    float   rope_freq_scale       =  0.0f; // RoPE frequency scaling factor
+    float   yarn_ext_factor       = -1.0f; // YaRN extrapolation mix factor
+    float   yarn_attn_factor      =  1.0f; // YaRN magnitude scaling factor
+    float   yarn_beta_fast        = 32.0f; // YaRN low correction dim
+    float   yarn_beta_slow        =  1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx         =     0; // YaRN original context length
+    float   defrag_thold          =  0.1f; // KV cache defragmentation threshold
+
+    // offload params
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_gpu_layers      = -1;  // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu          = 0;   // the GPU that is used for scratch and small tensors
+    float   tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
+    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    ggml_backend_sched_eval_callback cb_eval = nullptr;
+    void * cb_eval_user_data                 = nullptr;
+
+    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
+
+    enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
+    enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
+
+    struct common_params_sampling    sampling;
+    struct common_params_speculative speculative;
+    struct common_params_vocoder     vocoder;
+
+    std::string model                = ""; // model path                                                    // NOLINT
+    std::string model_alias          = ""; // model alias                                                   // NOLINT
+    std::string model_url            = ""; // model url to download                                         // NOLINT
+    std::string hf_token             = ""; // HF token                                                      // NOLINT
+    std::string hf_repo              = ""; // HF repo                                                       // NOLINT
+    std::string hf_file              = ""; // HF file                                                       // NOLINT
+    std::string prompt               = "";                                                                  // NOLINT
+    std::string prompt_file          = ""; // store the external prompt file name                           // NOLINT
+    std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state             // NOLINT
+    std::string input_prefix         = ""; // string to prefix user inputs with                             // NOLINT
+    std::string input_suffix         = ""; // string to suffix user inputs with                             // NOLINT
+    std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding           // NOLINT
+    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
+    std::string logits_file          = ""; // file for saving *all* logits                                  // NOLINT
+    std::string rpc_servers          = ""; // comma separated list of RPC servers                           // NOLINT
+
+    std::vector<std::string> in_files;   // all input files
+    std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
+    std::vector<llama_model_kv_override> kv_overrides;
+
+    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
+    std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
+
+    std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
+
+    int32_t verbosity                  = 0;
+    int32_t control_vector_layer_start = -1; // layer range for control vector
+    int32_t control_vector_layer_end   = -1; // layer range for control vector
+
+    int32_t ppl_stride      = 0;     // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+    int32_t ppl_output_type = 0;     // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+                                     //                                       (which is more convenient to use for plotting)
+                                     //
+    bool   hellaswag        = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+    size_t hellaswag_tasks  = 400;   // number of tasks to use when computing the HellaSwag score
+
+    bool   winogrande       = false; // compute Winogrande score over random tasks from datafile supplied in prompt
+    size_t winogrande_tasks = 0;     // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
+
+    bool   multiple_choice  = false;  // compute TruthfulQA score over random tasks from datafile supplied in prompt
+    size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed
+
+    bool   kl_divergence    = false; // compute KL divergence
+
+    bool usage             = false; // print usage
+    bool use_color         = false; // use color to distinguish generations and inputs
+    bool special           = false; // enable special token output
+    bool interactive       = false; // interactive mode
+    bool interactive_first = false; // wait for user input immediately
+    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
+    bool prompt_cache_all  = false; // save user input and generations to prompt cache
+    bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
+
+    bool escape            = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
+    bool multiline_input   = false; // reverse the usage of `\`
+    bool simple_io         = false; // improves compatibility with subprocesses and limited consoles
+    bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
+    bool flash_attn        = false; // flash attention
+    bool no_perf           = false; // disable performance metrics
+    bool ctx_shift         = true;  // context shift on infinite text generation
+
+    bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
+    bool logits_all        = false; // return logits for all tokens in the batch
+    bool use_mmap          = true;  // use mmap for faster loads
+    bool use_mlock         = false; // use mlock to keep model in memory
+    bool verbose_prompt    = false; // print prompt tokens before generation
+    bool display_prompt    = true;  // print prompt before generation
+    bool dump_kv_cache     = false; // dump the KV cache contents for debugging purposes
+    bool no_kv_offload     = false; // disable KV offloading
+    bool warmup            = true;  // warmup run
+    bool check_tensors     = false; // validate tensor data
+
+    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
+    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
+
+    // multimodal models (see examples/llava)
+    std::string mmproj = "";        // path to multimodal projector                                         // NOLINT
+    std::vector<std::string> image; // path to image file(s)
+
+    // embedding
+    bool embedding         = false; // get only sentence embedding
+    int32_t embd_normalize = 2;     // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
+    std::string embd_out   = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
+    std::string embd_sep   = "\n";  // separator of embeddings
+    bool reranking         = false; // enable reranking support on server
+
+    // server params
+    int32_t port           = 8080;         // server listens on this network port
+    int32_t timeout_read   = 600;          // http read timeout in seconds
+    int32_t timeout_write  = timeout_read; // http write timeout in seconds
+    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting
+
+    std::string hostname      = "127.0.0.1";
+    std::string public_path   = "";                                                                         // NOLINT
+    std::string chat_template = "";                                                                         // NOLINT
+    bool enable_chat_template = true;
+
+    std::vector<std::string> api_keys;
+
+    std::string ssl_file_key  = "";                                                                         // NOLINT
+    std::string ssl_file_cert = "";                                                                         // NOLINT
+
+    // "advanced" endpoints are disabled by default for better security
+    bool webui            = true;
+    bool endpoint_slots   = false;
+    bool endpoint_props   = false; // only control POST requests, not GET
+    bool endpoint_metrics = false;
+
+    bool log_json = false;
+
+    std::string slot_save_path;
+
+    float slot_prompt_similarity = 0.5f;
+
+    // batched-bench params
+    bool is_pp_shared = false;
+
+    std::vector<int32_t> n_pp;
+    std::vector<int32_t> n_tg;
+    std::vector<int32_t> n_pl;
+
+    // retrieval params
+    std::vector<std::string> context_files; // context files to embed
+
+    int32_t chunk_size = 64; // chunk size for context embedding
+
+    std::string chunk_separator = "\n"; // chunk separator for context embedding
+
+    // passkey params
+    int32_t n_junk = 250; // number of times to repeat the junk text
+    int32_t i_pos  = -1;  // position of the passkey in the junk text
+
+    // imatrix params
+    std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
+
+    int32_t n_out_freq  = 10; // output the imatrix every n_out_freq iterations
+    int32_t n_save_freq =  0; // save the imatrix every n_save_freq iterations
+    int32_t i_chunk     =  0; // start processing from this chunk
+
+    bool process_output = false; // collect data for the output tensor
+    bool compute_ppl    = true;  // whether to compute perplexity
+
+    // cvector-generator params
+    int n_pca_batch = 100;
+    int n_pca_iterations = 1000;
+    dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
+    std::string cvector_outfile       = "control_vector.gguf";
+    std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
+    std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
+
+    bool spm_infill = false; // suffix/prefix/middle pattern for infill
+
+    std::string lora_outfile = "ggml-lora-merged-f16.gguf";
+
+    // batched-bench params
+    bool batched_bench_output_jsonl = false;
+};
+
+// call once at the start of a program if it uses libcommon
+// initializes the logging system and prints info about the build
+void common_init();
+
+std::string common_params_get_system_info(const common_params & params);
+
+bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
+bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
+void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
+bool set_process_priority(enum ggml_sched_priority prio);
+
+//
+// String utils
+//
+
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+#else
+#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
+#endif
+
+LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
+std::string string_format(const char * fmt, ...);
+
+std::string string_strip(const std::string & str);
+std::string string_get_sortable_timestamp();
+
+void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
+
+template<typename T>
+static std::vector<T> string_split(const std::string & str, char delim) {
+    static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
+    std::vector<T> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    while (std::getline(str_stream, token, delim)) {
+        T value;
+        std::istringstream token_stream(token);
+        token_stream >> value;
+        values.push_back(value);
+    }
+    return values;
+}
+
+template<>
+std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
+{
+    std::vector<std::string> parts;
+    size_t begin_pos = 0;
+    size_t separator_pos = input.find(separator);
+    while (separator_pos != std::string::npos) {
+        std::string part = input.substr(begin_pos, separator_pos - begin_pos);
+        parts.emplace_back(part);
+        begin_pos = separator_pos + 1;
+        separator_pos = input.find(separator, begin_pos);
+    }
+    parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
+    return parts;
+}
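+
+// Example behaviour (illustrative): string_split<int>("1,2,3", ',') yields {1, 2, 3},
+// while the std::string specialization above preserves empty parts, e.g.
+// string_split<std::string>("a,,b", ',') yields {"a", "", "b"}.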
+
+static bool string_starts_with(const std::string & str,
+                               const std::string & prefix) {  // While we wait for C++20's std::string::starts_with...
+    return str.rfind(prefix, 0) == 0;
+}
+
+bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+void string_process_escapes(std::string & input);
+
+std::string string_from(bool value);
+std::string string_from(const std::vector<int> & values);
+std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
+std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
+
+//
+// Filesystem utils
+//
+
+bool fs_validate_filename(const std::string & filename);
+bool fs_create_directory_with_parents(const std::string & path);
+
+std::string fs_get_cache_directory();
+std::string fs_get_cache_file(const std::string & filename);
+
+//
+// Model utils
+//
+
+// note: defines object's lifetime
+struct common_init_result {
+    llama_model_ptr   model;
+    llama_context_ptr context;
+
+    std::vector<llama_lora_adapter_ptr> lora;
+};
+
+struct common_init_result     common_init_from_params(common_params & params);
+
+struct llama_model_params     common_model_params_to_llama  (      common_params & params);
+struct llama_context_params   common_context_params_to_llama(const common_params & params);
+struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
+
+struct llama_model * common_load_model_from_url(
+    const std::string & model_url,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
+struct llama_model * common_load_model_from_hf(
+    const std::string & repo,
+    const std::string & remote_path,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
+
+// clear LoRA adapters from context, then apply new list of adapters
+void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
+
+//
+// Batch utils
+//
+
+void common_batch_clear(struct llama_batch & batch);
+
+void common_batch_add(
+                 struct llama_batch & batch,
+                        llama_token   id,
+                          llama_pos   pos,
+    const std::vector<llama_seq_id> & seq_ids,
+                               bool   logits);
+
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longet common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
+//
+// Vocab utils
+//
+
+// tokenizes a string into a vector of tokens
+// should work similar to Python's `tokenizer.encode`
+std::vector<llama_token> common_tokenize(
+  const struct llama_context * ctx,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special = false);
+
+std::vector<llama_token> common_tokenize(
+    const struct llama_model * model,
+           const std::string & text,
+                        bool   add_special,
+                        bool   parse_special = false);
+
+// tokenizes a token into a piece, optionally renders special/control tokens
+// should work similar to Python's `tokenizer.id_to_piece`
+std::string common_token_to_piece(
+        const struct llama_context * ctx,
+                       llama_token   token,
+                       bool          special = true);
+
+// detokenizes a vector of tokens into a string
+// should work similar to Python's `tokenizer.decode`
+// optionally renders special/control tokens
+std::string common_detokenize(
+                         llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+                                  bool   special = true);
+
+//
+// Chat template utils
+//
+
+// same with llama_chat_message, but uses std::string
+struct common_chat_msg {
+    std::string role;
+    std::string content;
+};
+
+// Get the built-in chat template for the model. Return empty string if not present.
+std::string common_get_builtin_chat_template(const struct llama_model * model);
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl);
+
+// CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
+std::string common_chat_apply_template(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<common_chat_msg> & chat,
+        bool add_ass);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string common_chat_format_single(const struct llama_model * model,
+        const std::string & tmpl,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass);
+
+// Returns an example of formatted chat
+std::string common_chat_format_example(const struct llama_model * model,
+        const std::string & tmpl);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+//
+// Embedding utils
+//
+
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
+
+float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
+
+//
+// Control vector utils
+//
+
+struct common_control_vector_data {
+    int n_embd;
+
+    // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd
+    std::vector<float> data;
+};
+
+struct common_control_vector_load_info {
+    float strength;
+
+    std::string fname;
+};
+
+// Load control vectors, scale each by strength, and add them together.
+// On error, returns {-1, empty}
+common_control_vector_data common_control_vector_load(const std::vector<common_control_vector_load_info> & load_infos);
+
+//
+// Split utils
+//
+
+namespace {
+
+const char * const LLM_KV_SPLIT_NO            = "split.no";
+const char * const LLM_KV_SPLIT_COUNT         = "split.count";
+const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
+
+}
diff --git a/llama/ggml-alloc.c b/llama/ggml-alloc.c
new file mode 100644
index 000000000..6ea83a901
--- /dev/null
+++ b/llama/ggml-alloc.c
@@ -0,0 +1,1063 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-alloc.h"
+#include "ggml-backend-impl.h"
+#include "ggml.h"
+#include "ggml-impl.h"
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MAX_FREE_BLOCKS 256
+
+//#define GGML_ALLOCATOR_DEBUG
+
+//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
+#define AT_PRINTF(...)
+
+
+static bool ggml_is_view(const struct ggml_tensor * t) {
+    return t->view_src != NULL;
+}
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool ggml_op_can_inplace(enum ggml_op op) {
+    switch (op) {
+        case GGML_OP_SCALE:
+        case GGML_OP_DIAG_MASK_ZERO:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_LOG:
+        case GGML_OP_UNARY:
+        case GGML_OP_ROPE:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SOFT_MAX:
+            return true;
+
+        default:
+            return false;
+    }
+}
+
+static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
+    assert(alignment && !(alignment & (alignment - 1))); // power of 2
+    size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
+    return offset + align;
+}
+
+// tallocr
+
+struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer) {
+    void * base = ggml_backend_buffer_get_base(buffer);
+    size_t align = ggml_backend_buffer_get_alignment(buffer);
+
+    assert(align && !(align & (align - 1))); // power of 2
+
+    struct ggml_tallocr talloc = (struct ggml_tallocr) {
+        /*.buffer    = */ buffer,
+        /*.base      = */ base,
+        /*.alignment = */ align,
+        /*.offset    = */ aligned_offset(base, 0, align),
+    };
+    return talloc;
+}
+
+void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor) {
+    size_t size = ggml_backend_buffer_get_alloc_size(talloc->buffer, tensor);
+    size = GGML_PAD(size, talloc->alignment);
+
+    if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
+        GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+                __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
+        GGML_ABORT("not enough space in the buffer");
+    }
+
+    void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset;
+    talloc->offset += size;
+
+    assert(((uintptr_t)addr % talloc->alignment) == 0);
+
+    ggml_backend_tensor_alloc(talloc->buffer, tensor, addr);
+}
+
+// dynamic tensor allocator
+
+struct free_block {
+    size_t offset;
+    size_t size;
+};
+
+struct ggml_dyn_tallocr {
+    size_t alignment;
+    int n_free_blocks;
+    struct free_block free_blocks[MAX_FREE_BLOCKS];
+    size_t max_size;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    struct {
+        const struct ggml_tensor * tensor;
+        size_t offset;
+    } allocated_tensors[1024];
+#endif
+};
+
+#ifdef GGML_ALLOCATOR_DEBUG
+static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
+    for (int i = 0; i < 1024; i++) {
+        if (alloc->allocated_tensors[i].tensor == NULL) {
+            alloc->allocated_tensors[i].tensor = tensor;
+            alloc->allocated_tensors[i].offset = offset;
+            return;
+        }
+    }
+    GGML_ABORT("out of allocated_tensors");
+}
+static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) {
+    for (int i = 0; i < 1024; i++) {
+        if (alloc->allocated_tensors[i].offset == offset) {
+            alloc->allocated_tensors[i].tensor = NULL;
+            return;
+        }
+    }
+    GGML_ABORT("tried to free tensor %s not found\n", tensor->name);
+}
+#endif
+
+static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t size, const struct ggml_tensor * tensor) {
+    size = aligned_offset(NULL, size, alloc->alignment);
+
+    AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
+
+    size_t max_avail = 0;
+
+    // find the best fitting free block besides the last block
+    int best_fit_block = -1;
+    size_t best_fit_size = SIZE_MAX;
+    for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
+        struct free_block * block = &alloc->free_blocks[i];
+        max_avail = MAX(max_avail, block->size);
+        if (block->size >= size && block->size <= best_fit_size) {
+            best_fit_block = i;
+            best_fit_size = block->size;
+        }
+    }
+
+    if (best_fit_block == -1) {
+        // the last block is our last resort
+        struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
+        max_avail = MAX(max_avail, block->size);
+        if (block->size >= size) {
+            best_fit_block = alloc->n_free_blocks - 1;
+        } else {
+            // this should never happen
+            GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
+                    __func__, size, max_avail);
+            GGML_ABORT("not enough space in the buffer");
+        }
+    }
+
+    struct free_block * block = &alloc->free_blocks[best_fit_block];
+    size_t offset = block->offset;
+    block->offset = offset + size;
+    block->size -= size;
+    if (block->size == 0) {
+        // remove block if empty
+        alloc->n_free_blocks--;
+        for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
+            alloc->free_blocks[j] = alloc->free_blocks[j+1];
+        }
+    }
+
+    AT_PRINTF("block %d, offset %zu\n", best_fit_block, offset);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    add_allocated_tensor(alloc, offset, tensor);
+    size_t cur_max = offset + size;
+    if (cur_max > alloc->max_size) {
+        // sort allocated_tensors by offset
+        for (int i = 0; i < 1024; i++) {
+            for (int j = i + 1; j < 1024; j++) {
+                if (alloc->allocated_tensors[i].offset > alloc->allocated_tensors[j].offset) {
+                    const struct ggml_tensor * tmp_tensor = alloc->allocated_tensors[i].tensor;
+                    size_t tmp_offset = alloc->allocated_tensors[i].offset;
+                    alloc->allocated_tensors[i].tensor = alloc->allocated_tensors[j].tensor;
+                    alloc->allocated_tensors[i].offset = alloc->allocated_tensors[j].offset;
+                    alloc->allocated_tensors[j].tensor = tmp_tensor;
+                    alloc->allocated_tensors[j].offset = tmp_offset;
+                }
+            }
+        }
+        GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+        for (int i = 0; i < 1024; i++) {
+            if (alloc->allocated_tensors[i].tensor) {
+                GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+                    alloc->allocated_tensors[i].offset,
+                    alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),
+                    ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
+            }
+        }
+        GGML_LOG_DEBUG("\n");
+    }
+#endif
+
+    alloc->max_size = MAX(alloc->max_size, offset + size);
+
+    return offset;
+
+    GGML_UNUSED(tensor);
+}
+
+// this is a very naive implementation, but for our case the number of free blocks should be very small
+static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, size_t size, const struct ggml_tensor * tensor) {
+    size = aligned_offset(NULL, size, alloc->alignment);
+
+    AT_PRINTF("%s: freeing %s at %zu (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, offset, size, alloc->n_free_blocks);
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    remove_allocated_tensor(alloc, offset, tensor);
+#endif
+
+    // see if we can merge with an existing block
+    for (int i = 0; i < alloc->n_free_blocks; i++) {
+        struct free_block * block = &alloc->free_blocks[i];
+        // check if ptr is at the end of the block
+        if (block->offset + block->size == offset) {
+            block->size += size;
+            // check if we can merge with the next block
+            if (i < alloc->n_free_blocks - 1 && block->offset + block->size == alloc->free_blocks[i+1].offset) {
+                block->size += alloc->free_blocks[i+1].size;
+                alloc->n_free_blocks--;
+                for (int j = i+1; j < alloc->n_free_blocks; j++) {
+                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
+                }
+            }
+            return;
+        }
+        // check if ptr is at the beginning of the block
+        if (offset + size == block->offset) {
+            block->offset = offset;
+            block->size += size;
+            // check if we can merge with the previous block
+            if (i > 0 && alloc->free_blocks[i-1].offset + alloc->free_blocks[i-1].size == block->offset) {
+                alloc->free_blocks[i-1].size += block->size;
+                alloc->n_free_blocks--;
+                for (int j = i; j < alloc->n_free_blocks; j++) {
+                    alloc->free_blocks[j] = alloc->free_blocks[j+1];
+                }
+            }
+            return;
+        }
+    }
+    // otherwise, add a new block
+    GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
+    // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
+    int insert_pos = 0;
+    while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].offset < offset) {
+        insert_pos++;
+    }
+    // shift all blocks from insert_pos onward to make room for the new block
+    for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
+        alloc->free_blocks[i] = alloc->free_blocks[i-1];
+    }
+    // insert the new block
+    alloc->free_blocks[insert_pos].offset = offset;
+    alloc->free_blocks[insert_pos].size = size;
+    alloc->n_free_blocks++;
+
+    GGML_UNUSED(tensor);
+}
+
+static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) {
+    alloc->n_free_blocks = 1;
+    alloc->free_blocks[0].offset = 0;
+    alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
+    alloc->max_size = 0;
+
+#ifdef GGML_ALLOCATOR_DEBUG
+    for (int i = 0; i < 1024; i++) {
+        alloc->allocated_tensors[i].tensor = NULL;
+    }
+#endif
+}
+
+static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) {
+    struct ggml_dyn_tallocr * alloc = (struct ggml_dyn_tallocr *)malloc(sizeof(struct ggml_dyn_tallocr));
+
+    *alloc = (struct ggml_dyn_tallocr) {
+        /*.alignment     = */ alignment,
+        /*.n_free_blocks = */ 0,
+        /*.free_blocks   = */ {{0}},
+        /*.max_size      = */ 0,
+#ifdef GGML_ALLOCATOR_DEBUG
+        /*.allocated_tensors = */ {{0}},
+#endif
+    };
+
+    ggml_dyn_tallocr_reset(alloc);
+
+    return alloc;
+}
+
+static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
+    free(alloc);
+}
+
+static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
+    return alloc->max_size;
+}
+
+
+/////////////////////////////////////
+
+// graph allocator
+
+struct hash_node {
+    int n_children;
+    int n_views;
+    int buffer_id;
+    size_t offset; // offset within the buffer
+    bool allocated;
+};
+
+struct tensor_alloc {
+    int buffer_id;
+    size_t offset;
+    size_t size_max; // 0 = pre-allocated, unused, or view
+};
+
+struct leaf_alloc {
+    struct tensor_alloc leaf;
+};
+
+struct node_alloc {
+    struct tensor_alloc dst;
+    struct tensor_alloc src[GGML_MAX_SRC];
+};
+
+struct ggml_gallocr {
+    ggml_backend_buffer_type_t * bufts; // [n_buffers]
+    ggml_backend_buffer_t * buffers; // [n_buffers]
+    struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
+    int n_buffers;
+
+    struct ggml_hash_set hash_set;
+    struct hash_node * hash_values; // [hash_set.size]
+
+    struct node_alloc * node_allocs; // [n_nodes]
+    int n_nodes;
+
+    struct leaf_alloc * leaf_allocs; // [n_leafs]
+    int n_leafs;
+};
+
+ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs) {
+    ggml_gallocr_t galloc = (ggml_gallocr_t)calloc(1, sizeof(struct ggml_gallocr));
+    GGML_ASSERT(galloc != NULL);
+
+    galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t));
+    GGML_ASSERT(galloc->bufts != NULL);
+
+    galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
+    GGML_ASSERT(galloc->buffers != NULL);
+
+    galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
+    GGML_ASSERT(galloc->buf_tallocs != NULL);
+
+    for (int i = 0; i < n_bufs; i++) {
+        galloc->bufts[i] = bufts[i];
+        galloc->buffers[i] = NULL;
+
+        // check if the same buffer type is used multiple times and reuse the same allocator
+        for (int j = 0; j < i; j++) {
+            if (bufts[i] == bufts[j]) {
+                galloc->buf_tallocs[i] = galloc->buf_tallocs[j];
+                break;
+            }
+        }
+
+        if (galloc->buf_tallocs[i] == NULL) {
+            size_t alignment = ggml_backend_buft_get_alignment(bufts[i]);
+            galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment);
+        }
+    }
+    galloc->n_buffers = n_bufs;
+
+    return galloc;
+}
+
+ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft) {
+    return ggml_gallocr_new_n(&buft, 1);
+}
+
+void ggml_gallocr_free(ggml_gallocr_t galloc) {
+    if (galloc == NULL) {
+        return;
+    }
+
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        if (galloc->buffers != NULL) {
+            // skip if already freed
+            bool freed = false;
+            for (int j = 0; j < i; j++) {
+                if (galloc->buffers[j] == galloc->buffers[i]) {
+                    freed = true;
+                    break;
+                }
+            }
+            if (!freed) {
+                ggml_backend_buffer_free(galloc->buffers[i]);
+            }
+        }
+        if (galloc->buf_tallocs != NULL) {
+            // skip if already freed
+            bool freed = false;
+            for (int j = 0; j < i; j++) {
+                if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                    freed = true;
+                    break;
+                }
+            }
+            if (!freed) {
+                ggml_dyn_tallocr_free(galloc->buf_tallocs[i]);
+            }
+        }
+    }
+
+    ggml_hash_set_free(&galloc->hash_set);
+    free(galloc->hash_values);
+    free(galloc->bufts);
+    free(galloc->buffers);
+    free(galloc->buf_tallocs);
+    free(galloc->node_allocs);
+    free(galloc->leaf_allocs);
+    free(galloc);
+}
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    size_t i = ggml_hash_find_or_insert(&galloc->hash_set, t);
+    return &galloc->hash_values[i];
+}
+
+static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    return ggml_gallocr_hash_get(galloc, t)->allocated;
+}
+
+static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
+    return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
+}
+
+static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+    GGML_ASSERT(buffer_id >= 0);
+    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+
+    if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
+        hn->allocated = true;
+        assert(hn->offset == 0);
+
+        // try to reuse a parent's buffer (inplace)
+        if (ggml_op_can_inplace(node->op)) {
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                struct ggml_tensor * parent = node->src[i];
+                if (parent == NULL) {
+                    continue;
+                }
+
+                // if the node's data is external, then we cannot re-use it
+                if (!ggml_gallocr_is_own(galloc, parent)) {
+                    AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
+                    continue;
+                }
+
+                // outputs cannot be reused
+                if (parent->flags & GGML_TENSOR_FLAG_OUTPUT || (parent->view_src != NULL && parent->view_src->flags & GGML_TENSOR_FLAG_OUTPUT)) {
+                    AT_PRINTF("not reusing parent %s for %s as it is an output\n", parent->name, node->name);
+                    continue;
+                }
+
+                if (!ggml_are_same_layout(node, parent)) {
+                    AT_PRINTF("not reusing parent %s for %s as layouts are different\n", parent->name, node->name);
+                    continue;
+                }
+
+                struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+                if (p_hn->n_children == 1 && p_hn->n_views == 0) {
+                    if (ggml_is_view(parent)) {
+                        struct ggml_tensor * view_src = parent->view_src;
+                        struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+                        if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
+                            AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
+                            assert(view_src_hn->offset == p_hn->offset);
+                            hn->buffer_id = p_hn->buffer_id;
+                            hn->offset = p_hn->offset;
+                            p_hn->allocated = false; // avoid freeing the parent
+                            view_src_hn->allocated = false;
+                            return;
+                        }
+                    } else {
+                        AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
+                        hn->buffer_id = p_hn->buffer_id;
+                        hn->offset = p_hn->offset;
+                        p_hn->allocated = false; // avoid freeing the parent
+                        return;
+                    }
+                }
+            }
+        }
+        // allocate tensor from the buffer
+        struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+        ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+        size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+        size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node);
+        hn->buffer_id = buffer_id;
+        hn->offset = offset;
+    }
+}
+
+static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
+    // graph outputs are never freed
+    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+        AT_PRINTF("not freeing output %s\n", node->name);
+        return;
+    }
+
+    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+    size_t offset = hn->offset;
+    int buffer_id = hn->buffer_id;
+    struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id];
+    ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id];
+    size_t size = ggml_backend_buft_get_alloc_size(buft, node);
+    ggml_dyn_tallocr_free_tensor(alloc, offset, size, node);
+    hn->allocated = false;
+}
+
+static int get_node_buffer_id(const int * node_buffer_ids, int i) {
+    return node_buffer_ids ? node_buffer_ids[i] : 0;
+}
+
+static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+    // clear hash tables
+    ggml_hash_set_reset(&galloc->hash_set);
+    memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size);
+
+    // allocate leafs
+    // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        ggml_gallocr_allocate_node(galloc, leaf, get_node_buffer_id(leaf_buffer_ids, i));
+    }
+
+    // count number of children and views
+    // allocate other graph inputs and leafs first to avoid overwriting them
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+
+        // TODO: better way to add external dependencies
+        // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to
+        // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node
+        // itself is never used and should not be considered a dependency
+        if (ggml_is_view(node) && node->op != GGML_OP_NONE) {
+            struct ggml_tensor * view_src = node->view_src;
+            ggml_gallocr_hash_get(galloc, view_src)->n_views += 1;
+        }
+
+        if (node->flags & GGML_TENSOR_FLAG_INPUT) {
+            ggml_gallocr_allocate_node(galloc, graph->nodes[i], get_node_buffer_id(node_buffer_ids, i));
+        }
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+
+            ggml_gallocr_hash_get(galloc, src)->n_children += 1;
+
+            // allocate explicit inputs
+            if (src->flags & GGML_TENSOR_FLAG_INPUT) {
+                ggml_gallocr_allocate_node(galloc, src, get_node_buffer_id(node_buffer_ids, i));
+            }
+        }
+    }
+
+    // allocate tensors
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        int buffer_id = get_node_buffer_id(node_buffer_ids, i);
+
+        // allocate parents (only leafs need to be allocated at this point)
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            ggml_gallocr_allocate_node(galloc, parent, buffer_id);
+        }
+
+        // allocate node
+        ggml_gallocr_allocate_node(galloc, node, buffer_id);
+
+        AT_PRINTF("exec: %s (%s) <= ", ggml_op_desc(node), node->name);
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            AT_PRINTF("%s", parent->name);
+            if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
+                AT_PRINTF(", ");
+            }
+        }
+        AT_PRINTF("\n");
+
+        // update parents
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * parent = node->src[j];
+            if (parent == NULL) {
+                continue;
+            }
+            struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
+            p_hn->n_children -= 1;
+
+            AT_PRINTF("parent %s: %d children, %d views, allocated: %d\n",
+                parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated);
+
+            if (p_hn->n_children == 0 && p_hn->n_views == 0) {
+                if (ggml_is_view(parent)) {
+                    struct ggml_tensor * view_src = parent->view_src;
+                    struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src);
+                    view_src_hn->n_views -= 1;
+                    AT_PRINTF("view_src %s: %d children, %d views\n",
+                        view_src->name, view_src_hn->n_children, view_src_hn->n_views);
+                    if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) {
+                        ggml_gallocr_free_node(galloc, view_src);
+                    }
+                }
+                else if (p_hn->allocated) {
+                    ggml_gallocr_free_node(galloc, parent);
+                }
+            }
+            AT_PRINTF("\n");
+        }
+    }
+}
+
+bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) {
+    size_t min_hash_size = graph->n_nodes + graph->n_leafs;
+    // add 25% margin to avoid hash collisions
+    min_hash_size += min_hash_size / 4;
+
+    // initialize hash table
+    if (galloc->hash_set.size < min_hash_size) {
+        ggml_hash_set_free(&galloc->hash_set);
+        galloc->hash_set = ggml_hash_set_new(min_hash_size);
+        GGML_ASSERT(galloc->hash_set.keys != NULL);
+
+        free(galloc->hash_values);
+        galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size);
+        GGML_ASSERT(galloc->hash_values != NULL);
+    }
+
+    // reset allocators
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        ggml_dyn_tallocr_reset(galloc->buf_tallocs[i]);
+    }
+
+    // allocate in hash table
+    ggml_gallocr_alloc_graph_impl(galloc, graph, node_buffer_ids, leaf_buffer_ids);
+
+    // set the node_allocs from the hash table
+    if (galloc->n_nodes < graph->n_nodes) {
+        free(galloc->node_allocs);
+        galloc->node_allocs = calloc(graph->n_nodes, sizeof(struct node_alloc));
+        GGML_ASSERT(galloc->node_allocs != NULL);
+    }
+    galloc->n_nodes = graph->n_nodes;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+        if (node->view_src || node->data) {
+            node_alloc->dst.buffer_id = -1;
+            node_alloc->dst.offset = SIZE_MAX;
+            node_alloc->dst.size_max = 0;
+        } else {
+            struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
+            node_alloc->dst.buffer_id = hn->buffer_id;
+            node_alloc->dst.offset    = hn->offset;
+            node_alloc->dst.size_max  = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
+        }
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (!src || src->view_src || src->data) {
+                node_alloc->src[j].buffer_id = -1;
+                node_alloc->src[j].offset = SIZE_MAX;
+                node_alloc->src[j].size_max = 0;
+            } else {
+                struct hash_node * hn = ggml_gallocr_hash_get(galloc, src);
+                node_alloc->src[j].buffer_id = hn->buffer_id;
+                node_alloc->src[j].offset   = hn->offset;
+                node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src);
+            }
+        }
+    }
+    if (galloc->n_leafs < graph->n_leafs) {
+        free(galloc->leaf_allocs);
+        galloc->leaf_allocs = calloc(graph->n_leafs, sizeof(galloc->leaf_allocs[0]));
+        GGML_ASSERT(galloc->leaf_allocs != NULL);
+    }
+    galloc->n_leafs = graph->n_leafs;
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
+        if (leaf->view_src || leaf->data) {
+            galloc->leaf_allocs[i].leaf.buffer_id = -1;
+            galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
+            galloc->leaf_allocs[i].leaf.size_max = 0;
+        } else {
+            galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id;
+            galloc->leaf_allocs[i].leaf.offset = hn->offset;
+            galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf);
+        }
+    }
+
+    // reallocate buffers if needed
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        // if the buffer type is used multiple times, we reuse the same buffer
+        for (int j = 0; j < i; j++) {
+            if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) {
+                galloc->buffers[i] = galloc->buffers[j];
+                break;
+            }
+        }
+
+        size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0;
+        size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
+
+        // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
+        if (new_size > cur_size || galloc->buffers[i] == NULL) {
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+
+            ggml_backend_buffer_free(galloc->buffers[i]);
+            galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
+            if (galloc->buffers[i] == NULL) {
+                GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+                return false;
+            }
+            ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
+        }
+    }
+
+    return true;
+}
+
+bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
+    return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL);
+}
+
+static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) {
+    int buffer_id = tensor_alloc->buffer_id;
+    assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+
+    if (tensor->view_src != NULL) {
+        if (tensor->buffer == NULL) {
+            assert(tensor_alloc->offset == SIZE_MAX);
+            if (tensor->view_src->buffer == NULL) {
+                // this tensor was allocated without ggml-backend
+                return;
+            }
+            ggml_backend_view_init(tensor);
+        }
+    } else {
+        if (tensor->data == NULL) {
+            assert(tensor_alloc->offset != SIZE_MAX);
+            assert(ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max);
+            void * base = ggml_backend_buffer_get_base(galloc->buffers[buffer_id]);
+            void * addr = (char *)base + tensor_alloc->offset;
+            ggml_backend_tensor_alloc(galloc->buffers[buffer_id], tensor, addr);
+        } else {
+            if (tensor->buffer == NULL) {
+                // this tensor was allocated without ggml-backend
+                return;
+            }
+        }
+    }
+}
+
+static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
+    size_t node_size = 0;
+    if (!node->data && !node->view_src) {
+        GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
+        node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    }
+    return talloc->size_max >= node_size;
+}
+
+static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+    if (galloc->n_nodes != graph->n_nodes) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__);
+#endif
+        return true;
+    }
+
+    if (galloc->n_leafs != graph->n_leafs) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__);
+#endif
+        return true;
+    }
+
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+
+        if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name);
+#endif
+            return true;
+        }
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+#endif
+                return true;
+            }
+        }
+    }
+
+    return false;
+}
+
+bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
+    if (ggml_gallocr_needs_realloc(galloc, graph)) {
+        if (galloc->n_buffers == 1) {
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: reallocating buffers automatically\n", __func__);
+#endif
+            if (!ggml_gallocr_reserve(galloc, graph)) {
+                return false;
+            }
+        } else {
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+#endif
+            return false;
+        }
+    }
+
+    // reset buffers
+    for (int i = 0; i < galloc->n_buffers; i++) {
+        if (galloc->buffers[i] != NULL) {
+            ggml_backend_buffer_reset(galloc->buffers[i]);
+        }
+    }
+
+    // allocate the graph tensors from the previous assignments
+    // leafs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i];
+        ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf);
+    }
+    // nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct node_alloc * node_alloc = &galloc->node_allocs[i];
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]);
+        }
+        ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst);
+    }
+
+    return true;
+}
+
+size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
+    GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
+
+    if (galloc->buffers[buffer_id] == NULL) {
+        return 0;
+    }
+
+    for (int i = 0; i < buffer_id; i++) {
+        if (galloc->buffers[i] == galloc->buffers[buffer_id]) {
+            // this buffer is the same as a previous one due to the same buffer type being used multiple times
+            // only return the buffer size the first time it appears to avoid double counting
+            return 0;
+        }
+    }
+
+    return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
+}
+
+// utils
+
+static bool alloc_tensor_range(struct ggml_context * ctx,
+        struct ggml_tensor * first, struct ggml_tensor * last,
+        ggml_backend_buffer_type_t buft, size_t size,
+        ggml_backend_buffer_t ** buffers, size_t * n_buffers) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
+    if (buffer == NULL) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
+#endif
+        for (size_t i = 0; i < *n_buffers; i++) {
+            ggml_backend_buffer_free((*buffers)[i]);
+        }
+        free(*buffers);
+        return false;
+    }
+
+    struct ggml_tallocr tallocr = ggml_tallocr_new(buffer);
+
+    for (struct ggml_tensor * t = first; t != last; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->data == NULL) {
+            if (t->view_src == NULL) {
+                ggml_tallocr_alloc(&tallocr, t);
+            } else if (t->buffer == NULL) {
+                ggml_backend_view_init(t);
+            }
+        } else {
+            if (t->view_src != NULL && t->buffer == NULL) {
+                // view of a pre-allocated tensor
+                ggml_backend_view_init(t);
+            }
+        }
+    }
+
+    *buffers = realloc(*buffers, sizeof(ggml_backend_buffer_t) * (*n_buffers + 1));
+    (*buffers)[(*n_buffers)++] = buffer;
+
+    return true;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
+    GGML_ASSERT(ggml_get_no_alloc(ctx) == true);
+
+    size_t alignment = ggml_backend_buft_get_alignment(buft);
+    size_t max_size = ggml_backend_buft_get_max_size(buft);
+
+    ggml_backend_buffer_t * buffers = NULL;
+    size_t n_buffers = 0;
+
+    size_t cur_buf_size = 0;
+    struct ggml_tensor * first = ggml_get_first_tensor(ctx);
+    for (struct ggml_tensor * t = first; t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        size_t this_size = 0;
+        if (t->data == NULL && t->view_src == NULL) {
+            this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
+        }
+
+        if (this_size > max_size) {
+            GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
+                    __func__, t->name,
+                    ggml_backend_buft_name(buft),
+                    this_size, max_size);
+            for (size_t i = 0; i < n_buffers; i++) {
+                ggml_backend_buffer_free(buffers[i]);
+            }
+            free(buffers);
+            return NULL;
+        }
+
+        if ((cur_buf_size + this_size) > max_size) {
+            // allocate tensors in the current buffer
+            if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
+                return NULL;
+            }
+            first = t;
+            cur_buf_size = this_size;
+        } else {
+            cur_buf_size += this_size;
+        }
+    }
+
+    // allocate remaining tensors
+    if (cur_buf_size > 0) {
+        if (!alloc_tensor_range(ctx, first, NULL, buft, cur_buf_size, &buffers, &n_buffers)) {
+            return NULL;
+        }
+    }
+
+    if (n_buffers == 0) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__);
+#endif
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer;
+    if (n_buffers == 1) {
+        buffer = buffers[0];
+    } else {
+        buffer = ggml_backend_multi_buffer_alloc_buffer(buffers, n_buffers);
+    }
+    free(buffers);
+    return buffer;
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend) {
+    return ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_get_default_buffer_type(backend));
+}
diff --git a/llama/ggml-alloc.h b/llama/ggml-alloc.h
new file mode 100644
index 000000000..960ebf301
--- /dev/null
+++ b/llama/ggml-alloc.h
@@ -0,0 +1,102 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+typedef struct      ggml_backend_buffer * ggml_backend_buffer_t;
+typedef struct             ggml_backend * ggml_backend_t;
+
+// Tensor allocator
+struct ggml_tallocr {
+    ggml_backend_buffer_t buffer;
+    void * base;
+    size_t alignment;
+    size_t offset;
+};
+
+GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
+GGML_API void                ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
+
+// Graph allocator
+/*
+  Example usage:
+    ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
+
+    // optional: create a worst-case graph and reserve the buffers to avoid reallocations
+    ggml_gallocr_reserve(galloc, build_graph(max_batch));
+
+    // allocate the graph
+    struct ggml_cgraph * graph = build_graph(batch);
+    ggml_gallocr_alloc_graph(galloc, graph);
+
+    printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
+
+    // evaluate the graph
+    ggml_backend_graph_compute(backend, graph);
+*/
+
+// special tensor flags for use with the graph allocator:
+//   ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
+//   ggml_set_output(): output tensors are never freed and never overwritten
+
+typedef struct ggml_gallocr * ggml_gallocr_t;
+
+GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
+GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
+GGML_API void           ggml_gallocr_free(ggml_gallocr_t galloc);
+
+// pre-allocate buffers from a measure graph - does not allocate or modify the graph
+// call with a worst-case graph to avoid buffer reallocations
+// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
+// returns false if the buffer allocation failed
+GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+GGML_API bool ggml_gallocr_reserve_n(
+    ggml_gallocr_t galloc,
+    struct ggml_cgraph * graph,
+    const int * node_buffer_ids,
+    const int * leaf_buffer_ids);
+
+// automatic reallocation if the topology changes when using a single buffer
+// returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
+GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
+
+GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
+
+// Utils
+// Create a buffer and allocate all the tensors in a ggml_context
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
+
+#ifdef  __cplusplus
+}
+#endif
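The single-buffer flow is documented in the header comment above; the multi-buffer variant is less obvious, so here is a hedged sketch (the build_graph helper and the per-node/leaf buffer ids are illustrative, and a real setup would normally use a device buffer type for one of the slots) of ggml_gallocr_new_n together with ggml_gallocr_reserve_n:

    // illustrative sketch only - two buffer slots; the ids map each node/leaf to a slot
    ggml_backend_buffer_type_t bufts[2] = {
        ggml_backend_cpu_buffer_type(),   // slot 0
        ggml_backend_cpu_buffer_type(),   // slot 1 (stand-in for a device buffer type)
    };
    ggml_gallocr_t galloc = ggml_gallocr_new_n(bufts, 2);

    struct ggml_cgraph * graph = build_graph(max_batch); // hypothetical helper, as in the example above

    // node_ids[i] / leaf_ids[i] select the buffer slot for graph->nodes[i] / graph->leafs[i]
    int node_ids[GGML_DEFAULT_GRAPH_SIZE] = {0};
    int leaf_ids[GGML_DEFAULT_GRAPH_SIZE] = {0};
    ggml_gallocr_reserve_n(galloc, graph, node_ids, leaf_ids);

    // with more than one buffer, alloc_graph does not re-reserve automatically
    if (!ggml_gallocr_alloc_graph(galloc, graph)) {
        // topology changed: call ggml_gallocr_reserve_n again before computing
    }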
diff --git a/llama/ggml-backend-impl.h b/llama/ggml-backend-impl.h
new file mode 100644
index 000000000..37b592077
--- /dev/null
+++ b/llama/ggml-backend-impl.h
@@ -0,0 +1,282 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+// ggml-backend internal header
+
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    #define GGML_BACKEND_API_VERSION 1
+
+    //
+    // Backend buffer type
+    //
+
+    struct ggml_backend_buffer_type_i {
+        const char *          (*get_name)      (ggml_backend_buffer_type_t buft);
+        // allocate a buffer of this type
+        ggml_backend_buffer_t (*alloc_buffer)  (ggml_backend_buffer_type_t buft, size_t size);
+        // tensor alignment
+        size_t                (*get_alignment) (ggml_backend_buffer_type_t buft);
+        // (optional) max buffer size that can be allocated (defaults to SIZE_MAX)
+        size_t                (*get_max_size)  (ggml_backend_buffer_type_t buft);
+        // (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes)
+        size_t                (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
+        // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false)
+        bool                  (*is_host)       (ggml_backend_buffer_type_t buft);
+    };
+
+    struct ggml_backend_buffer_type {
+        struct ggml_backend_buffer_type_i  iface;
+        ggml_backend_dev_t device;
+        void * context;
+    };
+
+    //
+    // Backend buffer
+    //
+
+    struct ggml_backend_buffer_i {
+        // (optional) free the buffer
+        void         (*free_buffer)  (ggml_backend_buffer_t buffer);
+        // base address of the buffer
+        void *       (*get_base)     (ggml_backend_buffer_t buffer);
+        // (optional) initialize a tensor in the buffer (eg. add tensor extras)
+        void         (*init_tensor)  (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+        // tensor data access
+        void         (*memset_tensor)(ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
+        void         (*set_tensor)   (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void         (*get_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)
+        bool         (*cpy_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
+        // clear the entire buffer
+        void         (*clear)        (ggml_backend_buffer_t buffer, uint8_t value);
+        // (optional) reset any internal state due to tensor initialization, such as tensor extras
+        void         (*reset)        (ggml_backend_buffer_t buffer);
+    };
+
+    struct ggml_backend_buffer {
+        struct ggml_backend_buffer_i  iface;
+        ggml_backend_buffer_type_t    buft;
+        void * context;
+        size_t size;
+        enum ggml_backend_buffer_usage usage;
+    };
+
+    GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
+                   ggml_backend_buffer_type_t buft,
+            struct ggml_backend_buffer_i      iface,
+                   void *                     context,
+                   size_t                     size);
+
+    // do not use directly, use ggml_backend_tensor_copy instead
+    GGML_API bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    // multi-buffer
+    // buffer that contains a collection of buffers
+    GGML_API ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
+    GGML_API bool                  ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
+    GGML_API void                  ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+
+    //
+    // Backend (stream)
+    //
+
+    struct ggml_backend_i {
+        const char * (*get_name)(ggml_backend_t backend);
+
+        void (*free)(ggml_backend_t backend);
+
+        // (optional) asynchronous tensor data access
+        void (*set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+        bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
+
+        // (optional) complete all pending operations (required if the backend supports async operations)
+        void (*synchronize)(ggml_backend_t backend);
+
+        // (optional) graph plans (not used currently)
+        // compute graph with a plan
+        ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
+        void                      (*graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+        // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
+        void                      (*graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
+        // compute the graph with the plan
+        enum ggml_status          (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+        // compute graph (always async if supported by the backend)
+        enum ggml_status          (*graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+        // (optional) event synchronization
+        // record an event on this stream
+        void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
+        // wait for an event on a different stream
+        void (*event_wait)  (ggml_backend_t backend, ggml_backend_event_t event);
+    };
+
+    struct ggml_backend {
+        ggml_guid_t guid;
+        struct ggml_backend_i iface;
+        ggml_backend_dev_t device;
+        void * context;
+    };
+
+    struct ggml_backend_event {
+        struct ggml_backend_device * device;
+        void * context;
+    };
+
+    //
+    // Backend device
+    //
+
+    // Note: if additional properties are needed, we should add a struct with all of them
+    //       the current functions to obtain the properties can remain, since they are more convenient for often used properties
+    struct ggml_backend_device_i {
+        // device name: short identifier for this device, such as "CPU" or "CUDA0"
+        const char * (*get_name)(ggml_backend_dev_t dev);
+
+        // device description: short informative description of the device, could be the model name
+        const char * (*get_description)(ggml_backend_dev_t dev);
+
+        // device memory in bytes
+        void         (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total);
+
+        // device type
+        enum ggml_backend_dev_type (*get_type)(ggml_backend_dev_t dev);
+
+        // device properties
+        void (*get_props)(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props);
+
+        // backend (stream) initialization
+        ggml_backend_t (*init_backend)(ggml_backend_dev_t dev, const char * params);
+
+        // preferred buffer type
+        ggml_backend_buffer_type_t (*get_buffer_type)(ggml_backend_dev_t dev);
+
+        // (optional) host buffer type (in system memory, typically this is a pinned memory buffer for faster transfers between host and device)
+        ggml_backend_buffer_type_t (*get_host_buffer_type)(ggml_backend_dev_t dev);
+
+        // (optional) buffer from pointer: create a buffer from a host pointer (useful for memory mapped models and importing data from other libraries)
+        ggml_backend_buffer_t (*buffer_from_host_ptr)(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size);
+
+        // check if the backend can compute an operation
+        bool (*supports_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);
+
+        // check if the backend can use tensors allocated in a buffer type
+        bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft);
+
+        // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer
+        // these should be expensive operations that may benefit from running on this backend instead of the CPU backend
+        bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);
+
+        // (optional) event synchronization
+        ggml_backend_event_t (*event_new)         (ggml_backend_dev_t dev);
+        void                 (*event_free)        (ggml_backend_dev_t dev, ggml_backend_event_t event);
+        void                 (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
+    };
+
+    struct ggml_backend_device {
+        struct ggml_backend_device_i iface;
+        ggml_backend_reg_t reg;
+        void * context;
+    };
+
+    //
+    // Backend (reg)
+    //
+
+    struct ggml_backend_reg_i {
+        const char * (*get_name)(ggml_backend_reg_t reg);
+
+        // enumerate available devices
+        size_t             (*get_device_count)(ggml_backend_reg_t reg);
+        ggml_backend_dev_t (*get_device)(ggml_backend_reg_t reg, size_t index);
+
+        // (optional) get a pointer to a function in the backend
+        // backends can add custom functions that are not part of the standard ggml-backend interface
+        void * (*get_proc_address)(ggml_backend_reg_t reg, const char * name);
+    };
+
+    struct ggml_backend_reg {
+        int api_version; // initialize to GGML_BACKEND_API_VERSION
+        struct ggml_backend_reg_i iface;
+        void * context;
+    };
+
+    // Internal backend registry API
+    GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
+    // Add backend dynamic loading support to the backend
+
+    // Initialize the backend
+    typedef ggml_backend_reg_t (*ggml_backend_init_t)(void);
+    // Optional: obtain a score for the backend based on the system configuration
+    // Higher scores are preferred, 0 means the backend is not supported in the current system
+    typedef int                (*ggml_backend_score_t)(void);
+
+#ifdef GGML_BACKEND_DL
+#    ifdef __cplusplus
+#        define GGML_BACKEND_DL_IMPL(reg_fn)                             \
+            extern "C" {                                                 \
+            GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \
+            }                                                            \
+            ggml_backend_reg_t ggml_backend_init(void) {                 \
+                return reg_fn();                                         \
+            }
+#        define GGML_BACKEND_DL_SCORE_IMPL(score_fn)       \
+            extern "C" {                                   \
+            GGML_BACKEND_API int ggml_backend_score(void); \
+            }                                              \
+            int ggml_backend_score(void) {                 \
+                return score_fn();                         \
+            }
+#    else
+#        define GGML_BACKEND_DL_IMPL(reg_fn)                              \
+            GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void);  \
+            ggml_backend_reg_t                  ggml_backend_init(void) { \
+                return reg_fn();                                          \
+            }
+#        define GGML_BACKEND_DL_SCORE_IMPL(score_fn)        \
+            GGML_BACKEND_API int ggml_backend_score(void);  \
+            int                  ggml_backend_score(void) { \
+                return score_fn();                          \
+            }
+#    endif
+#else
+#    define GGML_BACKEND_DL_IMPL(reg_fn)
+#    define GGML_BACKEND_DL_SCORE_IMPL(score_fn)
+#endif
+
+#ifdef  __cplusplus
+}
+#endif
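To make the GGML_BACKEND_DL macros above concrete, a hypothetical out-of-tree backend compiled with GGML_BACKEND_DL would expose its entry points roughly as below (my_backend_reg and my_backend_score are illustrative names; the iface members are elided):

    // illustrative sketch only - hypothetical dynamically loaded backend
    #include "ggml-backend-impl.h"

    static ggml_backend_reg_t my_backend_reg(void) {
        static struct ggml_backend_reg reg = {
            /*.api_version =*/ GGML_BACKEND_API_VERSION,
            /*.iface       =*/ {0}, // get_name, get_device_count, get_device, get_proc_address
            /*.context     =*/ NULL,
        };
        return &reg;
    }

    static int my_backend_score(void) {
        return 1; // 0 means "not supported on this system" and the loader skips the library
    }

    // expands to the exported ggml_backend_init() / ggml_backend_score() symbols that
    // ggml_backend_load() resolves with dl_get_sym()
    GGML_BACKEND_DL_IMPL(my_backend_reg)
    GGML_BACKEND_DL_SCORE_IMPL(my_backend_score)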
diff --git a/llama/ggml-backend-reg.cpp b/llama/ggml-backend-reg.cpp
new file mode 100644
index 000000000..2ebc34398
--- /dev/null
+++ b/llama/ggml-backend-reg.cpp
@@ -0,0 +1,603 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include <algorithm>
+#include <codecvt>
+#include <cstring>
+#include <filesystem>
+#include <locale>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#ifdef _WIN32
+#    define WIN32_LEAN_AND_MEAN
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
+#    include <windows.h>
+#elif defined(__APPLE__)
+#    include <mach-o/dyld.h>
+#    include <dlfcn.h>
+#else
+#    include <dlfcn.h>
+#    include <unistd.h>
+#endif
+
+// Backend registry
+#ifdef GGML_USE_CPU
+#include "ggml-cpu.h"
+#endif
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#ifdef GGML_USE_OPENCL
+#include "ggml-opencl.h"
+#endif
+
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
+#ifdef GGML_USE_RPC
+#include "ggml-rpc.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
+// disable C++17 deprecation warning for std::codecvt_utf8
+#if defined(__clang__)
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+static std::wstring utf8_to_utf16(const std::string & str) {
+    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
+    return converter.from_bytes(str);
+}
+
+static std::string utf16_to_utf8(const std::wstring & str) {
+    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
+    return converter.to_bytes(str);
+}
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+    void operator()(HMODULE handle) {
+        FreeLibrary(handle);
+    }
+};
+
+static dl_handle * dl_load_library(const std::wstring & path) {
+    // suppress error dialogs for missing DLLs
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    HMODULE handle = LoadLibraryW(path.c_str());
+
+    SetErrorMode(old_mode);
+
+    return handle;
+}
+
+static void * dl_get_sym(dl_handle * handle, const char * name) {
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    void * p = (void *) GetProcAddress(handle, name);
+
+    SetErrorMode(old_mode);
+
+    return p;
+}
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+    void operator()(void * handle) {
+        dlclose(handle);
+    }
+};
+
+static void * dl_load_library(const std::wstring & path) {
+    dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
+
+    return handle;
+}
+
+static void * dl_get_sym(dl_handle * handle, const char * name) {
+    return dlsym(handle, name);
+}
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+struct ggml_backend_reg_entry {
+    ggml_backend_reg_t reg;
+    dl_handle_ptr handle;
+};
+
+struct ggml_backend_registry {
+    std::vector<ggml_backend_reg_entry> backends;
+    std::vector<ggml_backend_dev_t> devices;
+
+    ggml_backend_registry() {
+#ifdef GGML_USE_CUDA
+        register_backend(ggml_backend_cuda_reg());
+#endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif
+#ifdef GGML_USE_SYCL
+        register_backend(ggml_backend_sycl_reg());
+#endif
+#ifdef GGML_USE_VULKAN
+        register_backend(ggml_backend_vk_reg());
+#endif
+#ifdef GGML_USE_OPENCL
+        register_backend(ggml_backend_opencl_reg());
+#endif
+#ifdef GGML_USE_CANN
+        register_backend(ggml_backend_cann_reg());
+#endif
+// #ifdef GGML_USE_BLAS
+//         register_backend(ggml_backend_blas_reg());
+// #endif
+#ifdef GGML_USE_RPC
+        register_backend(ggml_backend_rpc_reg());
+#endif
+#ifdef GGML_USE_KOMPUTE
+        register_backend(ggml_backend_kompute_reg());
+#endif
+#ifdef GGML_USE_CPU
+        register_backend(ggml_backend_cpu_reg());
+#endif
+    }
+
+    ~ggml_backend_registry() {
+        // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources,
+        // since backend threads may still be running and accessing resources from the dynamic library
+        for (auto & entry : backends) {
+            if (entry.handle) {
+                entry.handle.release(); // NOLINT
+            }
+        }
+    }
+
+    void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {
+        if (!reg) {
+            return;
+        }
+
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
+            __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
+#endif
+        backends.push_back({ reg, std::move(handle) });
+        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
+            register_device(ggml_backend_reg_dev_get(reg, i));
+        }
+    }
+
+    void register_device(ggml_backend_dev_t device) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
+#endif
+        devices.push_back(device);
+    }
+
+    ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
+        dl_handle_ptr handle { dl_load_library(path) };
+        if (!handle) {
+            if (!silent) {
+                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
+            }
+            return nullptr;
+        }
+
+        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+        if (score_fn && score_fn() == 0) {
+            if (!silent) {
+                GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
+            }
+            return nullptr;
+        }
+
+        auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
+        if (!backend_init_fn) {
+            if (!silent) {
+                GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
+            }
+            return nullptr;
+        }
+
+        ggml_backend_reg_t reg = backend_init_fn();
+        if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
+            if (!silent) {
+                if (!reg) {
+                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
+                } else {
+                    GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
+                        __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
+                }
+            }
+            return nullptr;
+        }
+
+        GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
+
+        register_backend(reg, std::move(handle));
+
+        return reg;
+    }
+
+    void unload_backend(ggml_backend_reg_t reg, bool silent) {
+        auto it = std::find_if(backends.begin(), backends.end(),
+                               [reg](const ggml_backend_reg_entry & entry) { return entry.reg == reg; });
+
+        if (it == backends.end()) {
+            if (!silent) {
+                GGML_LOG_ERROR("%s: backend not found\n", __func__);
+            }
+            return;
+        }
+
+        if (!silent) {
+            GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, ggml_backend_reg_name(reg));
+        }
+
+        // remove devices
+        devices.erase(
+            std::remove_if(devices.begin(), devices.end(),
+                            [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }),
+            devices.end());
+
+        // remove backend
+        backends.erase(it);
+    }
+};
+
+static ggml_backend_registry & get_reg() {
+    static ggml_backend_registry reg;
+    return reg;
+}
+
+// Internal API
+void ggml_backend_register(ggml_backend_reg_t reg) {
+    get_reg().register_backend(reg);
+}
+
+void ggml_backend_device_register(ggml_backend_dev_t device) {
+    get_reg().register_device(device);
+}
+
+// Backend (reg) enumeration
+static bool striequals(const char * a, const char * b) {
+    for (; *a && *b; a++, b++) {
+        if (std::tolower(*a) != std::tolower(*b)) {
+            return false;
+        }
+    }
+    return *a == *b;
+}
+
+size_t ggml_backend_reg_count() {
+    return get_reg().backends.size();
+}
+
+ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
+    GGML_ASSERT(index < ggml_backend_reg_count());
+    return get_reg().backends[index].reg;
+}
+
+ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+        ggml_backend_reg_t reg = ggml_backend_reg_get(i);
+        if (striequals(ggml_backend_reg_name(reg), name)) {
+            return reg;
+        }
+    }
+    return nullptr;
+}
+
+// Device enumeration
+size_t ggml_backend_dev_count() {
+    return get_reg().devices.size();
+}
+
+ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
+    GGML_ASSERT(index < ggml_backend_dev_count());
+    return get_reg().devices[index];
+}
+
+ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (striequals(ggml_backend_dev_name(dev), name)) {
+            return dev;
+        }
+    }
+    return nullptr;
+}
+
+ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == type) {
+            return dev;
+        }
+    }
+    return nullptr;
+}
+
+// Convenience functions
+ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
+    if (!dev) {
+        return nullptr;
+    }
+    return ggml_backend_dev_init(dev, params);
+}
+
+ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
+    if (!dev) {
+        return nullptr;
+    }
+    return ggml_backend_dev_init(dev, params);
+}
+
+ggml_backend_t ggml_backend_init_best(void) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+    if (!dev) {
+        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    }
+    if (!dev) {
+        return nullptr;
+    }
+    return ggml_backend_dev_init(dev, nullptr);
+}
+
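+// Illustrative usage sketch (a minimal example, not a definitive recipe):
+// assuming the registry has been populated, e.g. via ggml_backend_load_all()
+// below, the devices can be enumerated and the best available backend
+// initialized, preferring a GPU and falling back to the CPU.
+//
+//     for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+//         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+//         printf("device %zu: %s (%s)\n", i,
+//                ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
+//     }
+//     ggml_backend_t backend = ggml_backend_init_best();
+//     // ... allocate buffers and compute graphs with the backend ...
+//     ggml_backend_free(backend);
+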
+// Dynamic loading
+ggml_backend_reg_t ggml_backend_load(const char * path) {
+    return get_reg().load_backend(utf8_to_utf16(path), false);
+}
+
+void ggml_backend_unload(ggml_backend_reg_t reg) {
+    get_reg().unload_backend(reg, true);
+}
+
+static std::wstring get_executable_path() {
+#if defined(__APPLE__)
+    // get executable path
+    std::vector<char> path;
+    uint32_t size;
+    while (true) {
+        size = path.size();
+        if (_NSGetExecutablePath(path.data(), &size) == 0) {
+            break;
+        }
+        path.resize(size);
+    }
+    std::string base_path(path.data(), size);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('/');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return utf8_to_utf16(base_path + "/");
+#elif defined(__linux__) || defined(__FreeBSD__)
+    std::string base_path = ".";
+    std::vector<char> path(1024);
+    while (true) {
+        // get executable path
+#    if defined(__linux__)
+        ssize_t len = readlink("/proc/self/exe", path.data(), path.size());
+#    elif defined(__FreeBSD__)
+        ssize_t len = readlink("/proc/curproc/file", path.data(), path.size());
+#    endif
+        if (len == -1) {
+            break;
+        }
+        if (len < (ssize_t) path.size()) {
+            base_path = std::string(path.data(), len);
+            // remove executable name
+            auto last_slash = base_path.find_last_of('/');
+            if (last_slash != std::string::npos) {
+                base_path = base_path.substr(0, last_slash);
+            }
+            break;
+        }
+        path.resize(path.size() * 2);
+    }
+
+    return utf8_to_utf16(base_path + "/");
+#elif defined(_WIN32)
+    std::vector<wchar_t> path(MAX_PATH);
+    DWORD len = GetModuleFileNameW(NULL, path.data(), path.size());
+    if (len == 0) {
+        return {};
+    }
+    std::wstring base_path(path.data(), len);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('\\');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + L"\\";
+#else
+    return {};
+#endif
+}
+
+static std::wstring backend_filename_prefix() {
+#ifdef _WIN32
+    return L"ggml-";
+#else
+    return L"libggml-";
+#endif
+}
+
+static std::wstring backend_filename_suffix() {
+#ifdef _WIN32
+    return L".dll";
+#else
+    return L".so";
+#endif
+}
+
+static std::wstring path_separator() {
+#ifdef _WIN32
+    return L"\\";
+#else
+    return L"/";
+#endif
+}
+
+static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
+    // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
+    // TODO: search system paths
+    std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-";
+    std::vector<std::wstring> search_paths;
+    if (user_search_path == nullptr) {
+        search_paths.push_back(L"." + path_separator());
+        search_paths.push_back(get_executable_path());
+    } else {
+        search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator());
+    }
+
+    int best_score = 0;
+    std::wstring best_path;
+
+    namespace fs = std::filesystem;
+    for (const auto & search_path : search_paths) {
+        if (!fs::exists(search_path)) {
+            continue;
+        }
+        fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
+        for (const auto & entry : dir_it) {
+            if (entry.is_regular_file()) {
+                std::wstring filename = entry.path().filename().wstring();
+                std::wstring ext = entry.path().extension().wstring();
+                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+                    dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+                    if (!handle && !silent) {
+                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+                    }
+                    if (handle) {
+                        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+                        if (score_fn) {
+                            int s = score_fn();
+#ifndef NDEBUG
+                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+#endif
+                            if (s > best_score) {
+                                best_score = s;
+                                best_path = entry.path().wstring();
+                            }
+                        } else {
+                            if (!silent) {
+                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    if (best_score == 0) {
+        // try to load the base backend
+        for (const auto & search_path : search_paths) {
+            std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix();
+            if (fs::exists(path)) {
+                return get_reg().load_backend(path, silent);
+            }
+        }
+        return nullptr;
+    }
+
+    return get_reg().load_backend(best_path, silent);
+}
+
+void ggml_backend_load_all() {
+    ggml_backend_load_all_from_path(nullptr);
+}
+
+void ggml_backend_load_all_from_path(const char * dir_path) {
+#ifdef NDEBUG
+    bool silent = true;
+#else
+    bool silent = false;
+#endif
+
+    ggml_backend_load_best("blas", silent, dir_path);
+    ggml_backend_load_best("cann", silent, dir_path);
+    ggml_backend_load_best("cuda", silent, dir_path);
+    ggml_backend_load_best("hip", silent, dir_path);
+    ggml_backend_load_best("kompute", silent, dir_path);
+    ggml_backend_load_best("metal", silent, dir_path);
+    ggml_backend_load_best("rpc", silent, dir_path);
+    ggml_backend_load_best("sycl", silent, dir_path);
+    ggml_backend_load_best("vulkan", silent, dir_path);
+    ggml_backend_load_best("opencl", silent, dir_path);
+    ggml_backend_load_best("musa", silent, dir_path);
+    ggml_backend_load_best("cpu", silent, dir_path);
+}
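+
+// Illustrative loading sketch: ggml_backend_load_all() scans the default search
+// paths (the current directory and the executable's directory), while
+// ggml_backend_load_all_from_path() scans a caller-supplied directory; the path
+// below is a hypothetical example. Backend names are matched case-insensitively.
+//
+//     ggml_backend_load_all();                           // default search paths
+//     ggml_backend_load_all_from_path("/opt/ggml/lib");  // hypothetical custom path
+//     ggml_backend_reg_t cuda = ggml_backend_reg_by_name("CUDA");
+//     if (cuda == nullptr) {
+//         // the CUDA backend library was not found or failed to load
+//     }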
diff --git a/llama/ggml-backend.cpp b/llama/ggml-backend.cpp
new file mode 100644
index 000000000..3e11d73fc
--- /dev/null
+++ b/llama/ggml-backend.cpp
@@ -0,0 +1,2033 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// Note: porting this file to C++ is a work in progress
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#   define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
+#include "ggml-alloc.h"
+#include "ggml-impl.h"
+
+#include <assert.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <string>
+#include <vector>
+
+#ifdef __APPLE__
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+
+// backend buffer type
+
+const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name(buft);
+}
+
+ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    if (size == 0) {
+        // return a dummy buffer for zero-sized allocations
+        return ggml_backend_buffer_init(buft, {}, NULL, 0);
+    }
+
+    return buft->iface.alloc_buffer(buft, size);
+}
+
+size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_alignment(buft);
+}
+
+size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
+    // get_max_size is optional, defaults to SIZE_MAX
+    if (buft->iface.get_max_size) {
+        return buft->iface.get_max_size(buft);
+    }
+    return SIZE_MAX;
+}
+
+size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) {
+    // get_alloc_size is optional, defaults to ggml_nbytes
+    if (buft->iface.get_alloc_size) {
+        size_t size = buft->iface.get_alloc_size(buft, tensor);
+        assert(size >= ggml_nbytes(tensor));
+        return size;
+    }
+    return ggml_nbytes(tensor);
+}
+
+bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
+    if (buft->iface.is_host) {
+        return buft->iface.is_host(buft);
+    }
+    return false;
+}
+
+ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {
+    return buft->device;
+}
+
+// backend buffer
+
+ggml_backend_buffer_t ggml_backend_buffer_init(
+               ggml_backend_buffer_type_t buft,
+        struct ggml_backend_buffer_i      iface,
+               void *                     context,
+               size_t                     size) {
+    ggml_backend_buffer_t buffer = new ggml_backend_buffer {
+        /* .interface = */ iface,
+        /* .buft      = */ buft,
+        /* .context   = */ context,
+        /* .size      = */ size,
+        /* .usage     = */ GGML_BACKEND_BUFFER_USAGE_ANY
+    };
+
+    return buffer;
+}
+
+const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_name(ggml_backend_buffer_get_type(buffer));
+}
+
+void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
+    if (buffer == NULL) {
+        return;
+    }
+
+    if (buffer->iface.free_buffer != NULL) {
+        buffer->iface.free_buffer(buffer);
+    }
+
+// TODO: this needs to be freed in the cuda and hip backends because
+// the cuda backend implementation is compiled with msvc
+#if !defined(GGML_USE_CUDA) && !defined(GGML_USE_HIP)
+    delete buffer;
+#endif
+}
+
+size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
+    return buffer->size;
+}
+
+void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // get_base is optional if the buffer is zero-sized
+    if (buffer->size == 0) {
+        return NULL;
+    }
+
+    void * base = buffer->iface.get_base(buffer);
+
+    GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
+
+    return base;
+}
+
+void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    // init_tensor is optional
+    if (buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    }
+}
+
+void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    // clear is optional if the buffer is zero-sized
+    if (buffer->size == 0) {
+        return;
+    }
+
+    buffer->iface.clear(buffer, value);
+}
+
+size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer));
+}
+
+size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
+}
+
+bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
+    return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
+}
+
+void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    buffer->usage = usage;
+
+    // FIXME: add a generic callback to the buffer interface
+    if (ggml_backend_buffer_is_multi_buffer(buffer)) {
+        ggml_backend_multi_buffer_set_usage(buffer, usage);
+    }
+}
+
+enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
+    return buffer->usage;
+}
+
+ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
+    return buffer->buft;
+}
+
+void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
+    if (buffer->iface.reset) {
+        buffer->iface.reset(buffer);
+    }
+}
+
+bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer;
+    if (dst_buf->iface.cpy_tensor) {
+        return dst_buf->iface.cpy_tensor(dst_buf, src, dst);
+    }
+    return false;
+}
+
+// backend
+
+ggml_guid_t ggml_backend_guid(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return NULL;
+    }
+    return backend->guid;
+}
+
+const char * ggml_backend_name(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return "NULL";
+    }
+    return backend->iface.get_name(backend);
+}
+
+void ggml_backend_free(ggml_backend_t backend) {
+    if (backend == NULL) {
+        return;
+    }
+
+    backend->iface.free(backend);
+}
+
+ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_dev_buffer_type(backend->device);
+}
+
+ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) {
+    return ggml_backend_buft_alloc_buffer(ggml_backend_get_default_buffer_type(backend), size);
+}
+
+size_t ggml_backend_get_alignment(ggml_backend_t backend) {
+    return ggml_backend_buft_get_alignment(ggml_backend_get_default_buffer_type(backend));
+}
+
+size_t ggml_backend_get_max_size(ggml_backend_t backend) {
+    return ggml_backend_buft_get_max_size(ggml_backend_get_default_buffer_type(backend));
+}
+
+void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    if (backend->iface.set_tensor_async == NULL) {
+        ggml_backend_tensor_set(tensor, data, offset, size);
+    } else {
+        backend->iface.set_tensor_async(backend, tensor, data, offset, size);
+    }
+}
+
+void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    if (backend->iface.get_tensor_async == NULL) {
+        ggml_backend_tensor_get(tensor, data, offset, size);
+    } else {
+        backend->iface.get_tensor_async(backend, tensor, data, offset, size);
+    }
+}
+
+void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    if (size == 0) {
+        return;
+    }
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+
+    buf->iface.set_tensor(buf, tensor, data, offset, size);
+}
+
+void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor);
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    if (size == 0) {
+        return;
+    }
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
+
+    buf->iface.get_tensor(buf, tensor, data, offset, size);
+}
+
+void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    if (size == 0) {
+        return;
+    }
+
+    GGML_ASSERT(buf != NULL && "tensor buffer not set");
+    GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
+    GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
+    GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not implemented by backend buffer");
+
+    buf->iface.memset_tensor(buf, tensor, value, offset, size);
+}
+
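+// Illustrative data-transfer sketch, assuming `t` is a tensor already allocated
+// in a backend buffer (for example by ggml-alloc): the synchronous setter and
+// getter below work regardless of which backend owns the buffer, copying through
+// the buffer interface.
+//
+//     std::vector<float> host(ggml_nelements(t));
+//     ggml_backend_tensor_set(t, host.data(), 0, ggml_nbytes(t)); // host -> backend
+//     ggml_backend_tensor_get(t, host.data(), 0, ggml_nbytes(t)); // backend -> host
+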
+void ggml_backend_synchronize(ggml_backend_t backend) {
+    if (backend->iface.synchronize == NULL) {
+        return;
+    }
+
+    backend->iface.synchronize(backend);
+}
+
+ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(backend->iface.graph_plan_create != NULL);
+
+    return backend->iface.graph_plan_create(backend, cgraph);
+}
+
+void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend->iface.graph_plan_free != NULL);
+
+    backend->iface.graph_plan_free(backend, plan);
+}
+
+enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
+
+    return backend->iface.graph_plan_compute(backend, plan);
+}
+
+enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
+    ggml_backend_synchronize(backend);
+    return err;
+}
+
+enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    return backend->iface.graph_compute(backend, cgraph);
+}
+
+bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    return ggml_backend_dev_supports_op(backend->device, op);
+}
+
+bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_dev_supports_buft(backend->device, buft);
+}
+
+bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
+    return ggml_backend_dev_offload_op(backend->device, op);
+}
+
+ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
+    return backend->device;
+}
+
+// backend copy
+
+static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
+    if (a->type != b->type) {
+        return false;
+    }
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (a->ne[i] != b->ne[i]) {
+            return false;
+        }
+        if (a->nb[i] != b->nb[i]) {
+            return false;
+        }
+    }
+    return true;
+}
+
+void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+    if (src == dst) {
+        return;
+    }
+
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
+    } else if (ggml_backend_buffer_is_host(dst->buffer)) {
+        ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
+    } else if (!ggml_backend_buffer_copy_tensor(src, dst)) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
+#endif
+        size_t nbytes = ggml_nbytes(src);
+        void * data = malloc(nbytes);
+        ggml_backend_tensor_get(src, data, 0, nbytes);
+        ggml_backend_tensor_set(dst, data, 0, nbytes);
+        free(data);
+    }
+}
+
+void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
+
+    if (src == dst) {
+        return;
+    }
+
+    if (backend_dst->iface.cpy_tensor_async != NULL) {
+        if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
+            return;
+        }
+    }
+
+    // an async copy would normally happen after all the queued operations on both backends are completed
+    // to simulate the same behavior, we need to synchronize both backends first, and do a blocking copy
+    ggml_backend_synchronize(backend_src);
+    ggml_backend_synchronize(backend_dst);
+    ggml_backend_tensor_copy(src, dst);
+}
+
+// events
+
+ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device) {
+    // null device is allowed for the transition period to the device interface
+    if (device == NULL || device->iface.event_new == NULL) {
+        return NULL;
+    }
+    return device->iface.event_new(device);
+}
+
+void ggml_backend_event_free(ggml_backend_event_t event) {
+    if (event == NULL) {
+        return;
+    }
+    event->device->iface.event_free(event->device, event);
+}
+
+void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {
+    GGML_ASSERT(backend->iface.event_record != NULL);
+
+    backend->iface.event_record(backend, event);
+}
+
+void ggml_backend_event_synchronize(ggml_backend_event_t event) {
+    GGML_ASSERT(event->device->iface.event_synchronize);
+
+    event->device->iface.event_synchronize(event->device, event);
+}
+
+void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    GGML_ASSERT(backend->iface.event_wait != NULL);
+
+    backend->iface.event_wait(backend, event);
+}
+
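+// Illustrative synchronization sketch, assuming `backend_a` and `backend_b` were
+// created from devices that implement the event interface: an event recorded on
+// one backend can be waited on by another backend, or synchronized with from the
+// host.
+//
+//     ggml_backend_event_t ev = ggml_backend_event_new(ggml_backend_get_device(backend_a));
+//     ggml_backend_event_record(ev, backend_a);   // mark a point in backend_a's queue
+//     ggml_backend_event_wait(backend_b, ev);     // backend_b waits for that point
+//     ggml_backend_event_synchronize(ev);         // or block the host until it is reached
+//     ggml_backend_event_free(ev);
+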
+// Backend device
+
+const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
+    return device->iface.get_name(device);
+}
+
+const char * ggml_backend_dev_description(ggml_backend_dev_t device) {
+    return device->iface.get_description(device);
+}
+
+void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+    device->iface.get_memory(device, free, total);
+}
+
+enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
+    return device->iface.get_type(device);
+}
+
+void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
+    memset(props, 0, sizeof(*props));
+    device->iface.get_props(device, props);
+}
+
+ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {
+    return device->reg;
+}
+
+ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {
+    return device->iface.init_backend(device, params);
+}
+
+ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
+    return device->iface.get_buffer_type(device);
+}
+
+ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+    if (device->iface.get_host_buffer_type == NULL) {
+        return NULL;
+    }
+
+    return device->iface.get_host_buffer_type(device);
+}
+
+ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {
+    return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);
+}
+
+bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+    return device->iface.supports_op(device, op);
+}
+
+bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {
+    return device->iface.supports_buft(device, buft);
+}
+
+bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
+    if (device->iface.offload_op != NULL) {
+        return device->iface.offload_op(device, op);
+    }
+
+    return false;
+}
+
+// Backend (reg)
+
+const char * ggml_backend_reg_name(ggml_backend_reg_t reg) {
+    return reg->iface.get_name(reg);
+}
+
+size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {
+    return reg->iface.get_device_count(reg);
+}
+
+ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {
+    return reg->iface.get_device(reg, index);
+}
+
+void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (!reg->iface.get_proc_address) {
+        return NULL;
+    }
+    return reg->iface.get_proc_address(reg, name);
+}
+
+// multi-buffer buffer
+
+struct ggml_backend_multi_buffer_context {
+    ggml_backend_buffer_t * buffers;
+    size_t n_buffers;
+};
+
+static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
+    for (size_t i = 0; i < ctx->n_buffers; i++) {
+        ggml_backend_buffer_free(ctx->buffers[i]);
+    }
+
+    free(ctx->buffers);
+    free(ctx);
+}
+
+static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
+    for (size_t i = 0; i < ctx->n_buffers; i++) {
+        ggml_backend_buffer_clear(ctx->buffers[i], value);
+    }
+}
+
+static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = {
+    /* .free_buffer     = */ ggml_backend_multi_buffer_free_buffer,
+    /* .get_base        = */ NULL,
+    /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ NULL,
+    /* .get_tensor      = */ NULL,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_multi_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) {
+    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) malloc(sizeof(struct ggml_backend_multi_buffer_context));
+    ctx->n_buffers = n_buffers;
+    ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t));
+
+    GGML_ASSERT(ctx->buffers != NULL);
+
+    size_t total_size = 0;
+    for (size_t i = 0; i < n_buffers; i++) {
+        ctx->buffers[i] = buffers[i];
+        total_size += ggml_backend_buffer_get_size(buffers[i]);
+    }
+
+    return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_i, ctx, total_size);
+}
+
+bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
+}
+
+void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
+    ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
+    for (size_t i = 0; i < ctx->n_buffers; i++) {
+        ggml_backend_buffer_set_usage(ctx->buffers[i], usage);
+    }
+}
+
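+// Illustrative sketch: independently allocated buffers can be wrapped into a
+// single multi-buffer so that they are freed, cleared, and have their usage set
+// as a unit. `buf_a` and `buf_b` are hypothetical buffers allocated elsewhere.
+//
+//     ggml_backend_buffer_t parts[2] = { buf_a, buf_b };
+//     ggml_backend_buffer_t multi = ggml_backend_multi_buffer_alloc_buffer(parts, 2);
+//     ggml_backend_buffer_set_usage(multi, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+//     ggml_backend_buffer_free(multi); // also frees buf_a and buf_b
+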
+// creates a copy of the tensor with the same memory layout
+static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
+    struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        dup->nb[i] = tensor->nb[i];
+    }
+    return dup;
+}
+
+static bool ggml_is_view_op(enum ggml_op op) {
+    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
+}
+
+// scheduler
+
+#ifndef GGML_SCHED_MAX_BACKENDS
+#define GGML_SCHED_MAX_BACKENDS 16
+#endif
+
+#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
+#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
+#endif
+
+#ifndef GGML_SCHED_MAX_COPIES
+#define GGML_SCHED_MAX_COPIES 4
+#endif
+
+struct ggml_backend_sched_split {
+    int backend_id;
+    int i_start;
+    int i_end;
+    struct ggml_tensor * inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+    int n_inputs;
+    // graph view of this split
+    struct ggml_cgraph graph;
+};
+
+struct ggml_backend_sched {
+    bool is_reset; // true if the scheduler has been reset since the last graph split
+    bool is_alloc;
+
+    int n_backends;
+
+    ggml_backend_t backends[GGML_SCHED_MAX_BACKENDS];
+    ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS];
+    ggml_gallocr_t galloc;
+
+    // hash map of the nodes in the graph
+    struct ggml_hash_set  hash_set;
+    int                 * hv_tensor_backend_ids; // [hash_set.size]
+    struct ggml_tensor ** hv_tensor_copies;      // [hash_set.size][n_backends][n_copies]
+
+    int * node_backend_ids; // [graph_size]
+    int * leaf_backend_ids; // [graph_size]
+
+    int * prev_node_backend_ids; // [graph_size]
+    int * prev_leaf_backend_ids; // [graph_size]
+
+    // copy of the graph with modified inputs
+    struct ggml_cgraph graph;
+
+    // graph splits
+    struct ggml_backend_sched_split * splits;
+    int n_splits;
+    int splits_capacity;
+
+    // pipeline parallelism support
+    int n_copies;
+    int cur_copy;
+    ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
+    struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
+    int n_graph_inputs;
+
+    struct ggml_context * ctx;
+
+    ggml_backend_sched_eval_callback callback_eval;
+    void * callback_eval_user_data;
+
+    char * context_buffer;
+    size_t context_buffer_size;
+
+    int debug;
+};
+
+#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor)
+#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)]
+#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)]
+#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id)
+
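+// Indexing sketch for hv_tensor_copies: the table is a flat array laid out as
+// [hash_set.size][n_backends][n_copies]. For example, with n_backends = 3 and
+// n_copies = 2, the copy of the tensor with hash id 5 on backend 1, copy 0 lives
+// at index 5*3*2 + 1*2 + 0 = 32, which is exactly what tensor_id_copy(5, 1, 0)
+// expands to.
+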
+// returns the priority of the backend, lower id is higher priority
+static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (sched->backends[i] == backend) {
+            return i;
+        }
+    }
+    return -1;
+}
+
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
+    ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+    if (buffer == NULL) {
+        return -1;
+    }
+
+    // find highest prio backend that supports the buffer type and the op
+    for (int i = 0; i < sched->n_backends; i++) {
+        if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) &&
+            ggml_backend_supports_op(sched->backends[i], op)) {
+            return i;
+        }
+    }
+
+#ifndef NDEBUG
+    GGML_LOG_DEBUG("%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
+        __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);
+#endif
+
+    return -1;
+}
+
+#if 0
+#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
+static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
+#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
+#define GET_CAUSE(node) causes[hash_id(node)]
+#else
+#define SET_CAUSE(node, ...)
+#define GET_CAUSE(node) ""
+#endif
+
+// returns the backend that should be used for the node based on the current locations
+static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
+    // assign pre-allocated nodes to their backend
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
+    if (cur_backend_id != -1) {
+        SET_CAUSE(tensor, "1.dst");
+        return cur_backend_id;
+    }
+
+    // view_src
+    if (tensor->view_src != NULL) {
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
+        if (cur_backend_id != -1) {
+            SET_CAUSE(tensor, "1.vsrc");
+            return cur_backend_id;
+        }
+    }
+
+    if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
+        // since the tensor is pre-allocated, it cannot be moved to another backend
+        ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+        GGML_ABORT("pre-allocated tensor (%s) in a buffer (%s) that cannot run the operation (%s)", tensor->name, ggml_backend_buffer_name(buffer), ggml_op_name(tensor->op));
+    }
+
+    // graph input
+    if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
+        cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
+        SET_CAUSE(tensor, "1.inp");
+        return cur_backend_id;
+    }
+
+    // operations with weights are preferably run on the same backend as the weights
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        const struct ggml_tensor * src = tensor->src[i];
+        if (src == NULL) {
+            continue;
+        }
+        // skip ROPE since the rope freqs tensor is too small to choose a backend based on it
+        // not an ideal solution
+        if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
+            // check if a backend with higher prio wants to offload the op
+            if (src_backend_id == sched->n_backends - 1) {
+                for (int b = 0; b < src_backend_id; b++) {
+                    if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
+                        SET_CAUSE(tensor, "1.off");
+                        return b;
+                    }
+                }
+            }
+            SET_CAUSE(tensor, "1.wgt%d", i);
+            return src_backend_id;
+        }
+    }
+
+    return -1;
+}
+
+static char * fmt_size(size_t size) {
+    static char buffer[128];
+    if (size >= 1024*1024) {
+        snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
+    } else {
+        snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
+    }
+    return buffer;
+}
+
+static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    int cur_split = 0;
+    for (int i = 0; i < graph->n_nodes; i++) {
+        if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
+            ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
+            GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, ggml_backend_name(split_backend),
+                sched->splits[cur_split].n_inputs);
+            for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
+                if (j == 0) {
+                    GGML_LOG_DEBUG(": ");
+                }
+                GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                    fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
+            }
+            GGML_LOG_DEBUG("\n");
+            cur_split++;
+        }
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        if (sched->debug > 1) {
+            ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
+                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    continue;
+                }
+                ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
+                GGML_LOG_DEBUG(" %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
+                    fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
+            }
+            GGML_LOG_DEBUG("\n");
+        }
+    }
+}
+
+static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) {
+    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+    ggml_backend_buffer_type_t buft = NULL;
+
+    if (buf) {
+        // the tensor is already allocated
+        buft = buf->buft;
+    } else {
+        // see if the tensor already has a backend assigned, and use the buffer type of that backend
+        int tensor_backend_id = tensor_backend_id(t);
+        if (tensor_backend_id == -1 && t->view_src) {
+            tensor_backend_id = tensor_backend_id(t->view_src);
+        }
+        if (tensor_backend_id != -1) {
+            buft = sched->bufts[tensor_backend_id];
+        }
+    }
+
+    return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft);
+}
+
+static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+        *node_backend_id = cur_backend_id;
+        SET_CAUSE(node, "2.sup");
+    }
+}
+
+// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
+static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    // reset splits
+    sched->n_splits = 0;
+    sched->n_graph_inputs = 0;
+    sched->is_reset = false;
+
+    struct ggml_init_params params = {
+        /* .mem_size =   */ sched->context_buffer_size,
+        /* .mem_buffer = */ sched->context_buffer,
+        /* .no_alloc =   */ true
+    };
+
+    ggml_free(sched->ctx);
+
+    sched->ctx = ggml_init(params);
+    if (sched->ctx == NULL) {
+        GGML_ABORT("%s: failed to initialize context\n", __func__);
+    }
+
+    // pass 1: assign backends to ops with pre-allocated inputs
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        int * leaf_backend_id = &tensor_backend_id(leaf);
+        // do not overwrite user assignments
+        if (*leaf_backend_id == -1) {
+            *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf);
+        }
+    }
+
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        int * node_backend_id = &tensor_backend_id(node);
+        // do not overwrite user assignments
+        if (*node_backend_id == -1) {
+            *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node);
+
+#if 0
+            // src
+            if (node->op == GGML_OP_NONE) {
+                continue;
+            }
+
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    continue;
+                }
+                int * src_backend_id = &tensor_backend_id(src);
+                if (*src_backend_id == -1) {
+                    *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src);
+                }
+            }
+#endif
+        }
+    }
+
+    // pass 2: expand current backend assignments
+    // assign the same backend to adjacent nodes
+    // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend)
+    // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
+    // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of their inputs are known
+    // expand gpu down
+    {
+        int cur_backend_id = -1;
+        for (int i = 0; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                if (*node_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                } else {
+                    cur_backend_id = *node_backend_id;
+                }
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+    // expand gpu up
+    {
+        int cur_backend_id = -1;
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                if (*node_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                } else {
+                    cur_backend_id = *node_backend_id;
+                }
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+    // expand rest down
+    {
+        int cur_backend_id = -1;
+        for (int i = 0; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                cur_backend_id = *node_backend_id;
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+    // expand rest up
+    {
+        int cur_backend_id = -1;
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+            int * node_backend_id = &tensor_backend_id(node);
+            if (*node_backend_id != -1) {
+                cur_backend_id = *node_backend_id;
+            } else if (cur_backend_id != -1) {
+                ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id);
+            }
+        }
+    }
+
+    // pass 3: upgrade nodes to higher prio backends with compatible buffer types
+    // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there
+    // however, we also need to verify that the sources are in compatible buffer types
+    // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph
+    // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same
+    // this is not uncommon since multiple backends can use host memory, with the same buffer type (e.g. BLAS and CPU)
+    // additionally, set remaining unassigned nodes to the backend with the most supported inputs
+    // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        if (ggml_is_view_op(node->op)) {
+            continue;
+        }
+        int * node_backend_id = &tensor_backend_id(node);
+        if (*node_backend_id == -1) {
+            // unassigned node: find the backend with the most supported inputs
+            int n_supported_best = -1;
+            for (int b = 0; b < sched->n_backends; b++) {
+                if (ggml_backend_supports_op(sched->backends[b], node)) {
+                    int n_supported = 0;
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * src = node->src[j];
+                        if (src == NULL) {
+                            continue;
+                        }
+                        if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) {
+                            n_supported++;
+                        }
+                    }
+                    if (n_supported > n_supported_best) {
+                        n_supported_best = n_supported;
+                        *node_backend_id = b;
+                        SET_CAUSE(node, "3.best");
+                    }
+                }
+            }
+        } else {
+            // assigned node: upgrade to higher prio backend if possible
+            for (int b = 0; b < *node_backend_id; b++) {
+                if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) {
+                    bool supported = true;
+                    for (int j = 0; j < GGML_MAX_SRC; j++) {
+                        struct ggml_tensor * src = node->src[j];
+                        if (src == NULL) {
+                            continue;
+                        }
+                        if (!ggml_backend_sched_buffer_supported(sched, src, b)) {
+                            supported = false;
+                            break;
+                        }
+                    }
+                    if (supported) {
+                        *node_backend_id = b;
+                        SET_CAUSE(node, "3.upg");
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    // pass 4: assign backends to remaining src from dst and view_src
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        int * cur_backend_id = &tensor_backend_id(node);
+        if (node->view_src != NULL && *cur_backend_id == -1) {
+            *cur_backend_id = tensor_backend_id(node->view_src);
+            SET_CAUSE(node, "4.vsrc");
+        }
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            struct ggml_tensor * src = node->src[j];
+            if (src == NULL) {
+                continue;
+            }
+            int * src_backend_id = &tensor_backend_id(src);
+            if (*src_backend_id == -1) {
+                if (src->view_src != NULL) {
+                    // views are always on the same backend as the source
+                    *src_backend_id = tensor_backend_id(src->view_src);
+                    SET_CAUSE(src, "4.vsrc");
+                } else {
+                    *src_backend_id = *cur_backend_id;
+                    SET_CAUSE(src, "4.cur");
+                }
+            }
+        }
+    }
+
+    // pass 5: split graph, find tensors that need to be copied
+    {
+        int i_split = 0;
+        struct ggml_backend_sched_split * split = &sched->splits[0];
+        // find the backend of the first split, skipping view ops
+        int i = 0;
+        for (; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+            if (!ggml_is_view_op(node->op)) {
+                split->backend_id = tensor_backend_id(node);
+                break;
+            }
+        }
+        split->i_start = 0;
+        split->n_inputs = 0;
+        int cur_backend_id = split->backend_id;
+        for (; i < graph->n_nodes; i++) {
+            struct ggml_tensor * node = graph->nodes[i];
+
+            if (ggml_is_view_op(node->op)) {
+                continue;
+            }
+
+            const int node_backend_id = tensor_backend_id(node);
+
+            assert(node_backend_id != -1); // all nodes should be assigned by now
+
+            // check if we should start a new split based on the sources of the current node
+            bool need_new_split = false;
+            if (node_backend_id == cur_backend_id && split->n_inputs > 0) {
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    struct ggml_tensor * src = node->src[j];
+                    if (src == NULL) {
+                        continue;
+                    }
+                    // check if a weight is on a different and incompatible backend
+                    // by starting a new split, the memory of the previously offloaded weights can be reused
+                    if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+                        int src_backend_id = tensor_backend_id(src);
+                        if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
+                            need_new_split = true;
+                            break;
+                        }
+                    }
+                    // check if the split has too many inputs
+                    // FIXME: count the number of inputs instead of only checking when full
+                    if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
+                        const size_t id = hash_id(src);
+                        int src_backend_id = sched->hv_tensor_backend_ids[id];
+                        bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id);
+                        if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) {
+                            need_new_split = true;
+                            break;
+                        }
+                    }
+                }
+            }
+
+            if (node_backend_id != cur_backend_id || need_new_split) {
+                split->i_end = i;
+                i_split++;
+                if (i_split >= sched->splits_capacity) {
+                    sched->splits_capacity *= 2;
+                    sched->splits = (ggml_backend_sched_split *)
+                        realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
+                    GGML_ASSERT(sched->splits != NULL);
+                }
+                split = &sched->splits[i_split];
+                split->backend_id = node_backend_id;
+                split->i_start = i;
+                split->n_inputs = 0;
+                cur_backend_id = node_backend_id;
+            }
+
+            // find inputs that are not on the same backend
+            for (int j = 0; j < GGML_MAX_SRC; j++) {
+                struct ggml_tensor * src = node->src[j];
+                if (src == NULL) {
+                    continue;
+                }
+
+                size_t src_id = hash_id(src);
+                const int src_backend_id = sched->hv_tensor_backend_ids[src_id];
+                assert(src_backend_id != -1); // all inputs should be assigned by now
+
+                if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) {
+                    if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) {
+                        ggml_backend_t backend = sched->backends[src_backend_id];
+                        for (int c = 0; c < sched->n_copies; c++) {
+                            struct ggml_tensor * tensor_copy;
+                            if (c == sched->cur_copy) {
+                                tensor_copy = src; // use the original tensor as the current copy
+                            } else {
+                                tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                                ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+                            }
+                            if (sched->n_copies > 1) {
+                                ggml_set_input(tensor_copy);
+                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+                            }
+                            tensor_id_copy(src_id, src_backend_id, c) = tensor_copy;
+                            SET_CAUSE(tensor_copy, "4.cpy");
+                        }
+                        int n_graph_inputs = sched->n_graph_inputs++;
+                        GGML_ASSERT(n_graph_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+                        sched->graph_inputs[n_graph_inputs] = src;
+                    }
+                }
+
+                if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) {
+                    // create a copy of the input in the split's backend
+                    if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) {
+                        ggml_backend_t backend = sched->backends[cur_backend_id];
+                        for (int c = 0; c < sched->n_copies; c++) {
+                            struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
+                            ggml_format_name(tensor_copy, "%s#%s#%d", ggml_backend_name(backend), src->name, c);
+                            if (sched->n_copies > 1) {
+                                ggml_set_input(tensor_copy);
+                                ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor
+                            }
+                            tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy;
+                            SET_CAUSE(tensor_copy, "4.cpy");
+                        }
+                        int n_inputs = split->n_inputs++;
+                        GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS);
+                        split->inputs[n_inputs] = src;
+                    }
+                    node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy);
+                }
+            }
+        }
+        split->i_end = graph->n_nodes;
+        sched->n_splits = i_split + 1;
+    }
+
+    if (sched->debug) {
+        ggml_backend_sched_print_assignments(sched, graph);
+    }
+
+    // swap node_backend_ids and leaf_backend_ids with prevs
+    {
+        int * tmp = sched->node_backend_ids;
+        sched->node_backend_ids = sched->prev_node_backend_ids;
+        sched->prev_node_backend_ids = tmp;
+
+        tmp = sched->leaf_backend_ids;
+        sched->leaf_backend_ids = sched->prev_leaf_backend_ids;
+        sched->prev_leaf_backend_ids = tmp;
+    }
+
+    int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;
+    if (sched->graph.size < graph_size) {
+        sched->graph.size = graph_size;
+        sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
+        sched->graph.leafs = (ggml_tensor **) realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *));
+        GGML_ASSERT(sched->graph.nodes != NULL);
+        GGML_ASSERT(sched->graph.leafs != NULL);
+    }
+    sched->graph.n_nodes = 0;
+    sched->graph.n_leafs = 0;
+
+    struct ggml_cgraph * graph_copy = &sched->graph;
+
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &sched->splits[i];
+        split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
+
+        // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
+        for (int j = 0; j < split->n_inputs; j++) {
+            assert(graph_copy->size > (graph_copy->n_nodes + 1));
+
+            struct ggml_tensor * input = split->inputs[j];
+            const size_t input_id = hash_id(input);
+            struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy);
+
+            // add a dependency to the input source so that it is not freed before the copy is done
+            struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input);
+            input_dep->src[0] = input;
+            sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id];
+            graph_copy->nodes[graph_copy->n_nodes++] = input_dep;
+
+            // add a dependency to the input copy so that it is allocated at the start of the split
+            sched->node_backend_ids[graph_copy->n_nodes] = split->backend_id;
+            graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
+        }
+
+        for (int j = split->i_start; j < split->i_end; j++) {
+            assert(graph_copy->size > graph_copy->n_nodes);
+            sched->node_backend_ids[graph_copy->n_nodes] = tensor_backend_id(graph->nodes[j]);
+            graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
+        }
+    }
+
+    if (sched->n_copies > 1) {
+        // add input copies as leafs so that they are allocated first
+        for (int i = 0; i < sched->n_graph_inputs; i++) {
+            struct ggml_tensor * input = sched->graph_inputs[i];
+            size_t id = hash_id(input);
+            int backend_id = tensor_backend_id(input);
+            for (int c = 0; c < sched->n_copies; c++) {
+                struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
+                sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+                assert(graph_copy->size > graph_copy->n_leafs);
+                graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+            }
+        }
+
+        for (int i = 0; i < sched->n_splits; i++) {
+            struct ggml_backend_sched_split * split = &sched->splits[i];
+            int backend_id = split->backend_id;
+            for (int j = 0; j < split->n_inputs; j++) {
+                struct ggml_tensor * input = split->inputs[j];
+                size_t id = hash_id(input);
+                for (int c = 0; c < sched->n_copies; c++) {
+                    struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
+                    sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
+                    assert(graph_copy->size > graph_copy->n_leafs);
+                    graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
+                }
+            }
+        }
+    }
+
+    // add leafs from the original graph
+    for (int i = 0; i < graph->n_leafs; i++) {
+        struct ggml_tensor * leaf = graph->leafs[i];
+        sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
+        assert(graph_copy->size > graph_copy->n_leafs);
+        graph_copy->leafs[graph_copy->n_leafs++] = leaf;
+    }
+}
+
+static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
+    bool backend_ids_changed = false;
+    for (int i = 0; i < sched->graph.n_nodes; i++) {
+        if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] &&
+            sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) {
+            backend_ids_changed = true;
+            break;
+        }
+    }
+    if (!backend_ids_changed) {
+        for (int i = 0; i < sched->graph.n_leafs; i++) {
+            if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] &&
+                sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) {
+                backend_ids_changed = true;
+                break;
+            }
+        }
+    }
+
+    // allocate graph
+    if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
+        // the re-allocation may cause the split inputs to be moved to a different address
+        ggml_backend_sched_synchronize(sched);
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
+#endif
+        ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
+        if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
+            GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
+    struct ggml_backend_sched_split * splits = sched->splits;
+
+    for (int i = 0; i < sched->n_splits; i++) {
+        struct ggml_backend_sched_split * split = &splits[i];
+        int split_backend_id = split->backend_id;
+        ggml_backend_t split_backend = sched->backends[split_backend_id];
+
+        // copy the input tensors to the split backend
+        for (int j = 0; j < split->n_inputs; j++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
+            struct ggml_tensor * input = split->inputs[j];
+            struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
+
+            if (input->flags & GGML_TENSOR_FLAG_INPUT) {
+                // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
+                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                    ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+                } else {
+                    ggml_backend_synchronize(split_backend);
+                }
+                ggml_backend_tensor_copy(input, input_cpy);
+            } else {
+                // wait for the split backend to finish using the input before overwriting it
+                if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                    ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
+                } else {
+                    ggml_backend_synchronize(split_backend);
+                }
+                // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
+                // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
+                if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
+                    ggml_backend_synchronize(input_backend);
+                    if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                        ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
+                    } else {
+                        ggml_backend_synchronize(split_backend);
+                    }
+                    ggml_backend_tensor_copy(input, input_cpy);
+                }
+            }
+        }
+
+        if (!sched->callback_eval) {
+            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
+            if (ec != GGML_STATUS_SUCCESS) {
+                return ec;
+            }
+        } else {
+            // similar to ggml_backend_compare_graph_backend
+            for (int j0 = 0; j0 < split->graph.n_nodes; j0++) {
+                struct ggml_tensor * t = split->graph.nodes[j0];
+
+                // check if the user needs data from this node
+                bool need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+
+                int j1 = j0;
+
+                // determine the range [j0, j1] of nodes that can be computed together
+                while (!need && j1 < split->graph.n_nodes - 1) {
+                    t = split->graph.nodes[++j1];
+                    need = sched->callback_eval(t, true, sched->callback_eval_user_data);
+                }
+
+                struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);
+
+                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
+                if (ec != GGML_STATUS_SUCCESS) {
+                    return ec;
+                }
+
+                // TODO: pass backend to the callback, then the user can decide if they want to synchronize
+                ggml_backend_synchronize(split_backend);
+
+                if (need && !sched->callback_eval(t, false, sched->callback_eval_user_data)) {
+                    break;
+                }
+
+                j0 = j1;
+            }
+        }
+
+        // record the event of this copy
+        if (split->n_inputs > 0) {
+            if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
+                ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy], split_backend);
+            }
+        }
+    }
+
+    sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
+
+    return GGML_STATUS_SUCCESS;
+}
+
+ggml_backend_sched_t ggml_backend_sched_new(
+        ggml_backend_t * backends,
+        ggml_backend_buffer_type_t * bufts,
+        int n_backends,
+        size_t graph_size,
+        bool parallel) {
+    GGML_ASSERT(n_backends > 0);
+    GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
+    GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
+
+    struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched));
+
+    const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG");
+    sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0;
+    sched->n_backends = n_backends;
+    sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1;
+
+    // initialize hash table
+    // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead)
+    sched->hash_set    = ggml_hash_set_new(graph_size);
+    sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+    sched->hv_tensor_copies      = (ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
+
+    const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
+    const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
+    sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
+    sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
+    sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
+    sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
+
+    sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
+    sched->context_buffer = (char *) malloc(sched->context_buffer_size);
+
+    const int initial_splits_capacity = 16;
+    sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0]));
+    sched->splits_capacity = initial_splits_capacity;
+
+    for (int b = 0; b < n_backends; b++) {
+        sched->backends[b] = backends[b];
+        sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
+        GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
+
+        if (sched->n_copies > 1) {
+            for (int c = 0; c < sched->n_copies; c++) {
+                sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
+            }
+        }
+    }
+
+    sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+
+    ggml_backend_sched_reset(sched);
+
+    return sched;
+}
+
+void ggml_backend_sched_free(ggml_backend_sched_t sched) {
+    if (sched == NULL) {
+        return;
+    }
+    for (int b = 0; b < sched->n_backends; b++) {
+        for (int c = 0; c < sched->n_copies; c++) {
+            ggml_backend_event_free(sched->events[b][c]);
+        }
+    }
+    ggml_gallocr_free(sched->galloc);
+    ggml_free(sched->ctx);
+    ggml_hash_set_free(&sched->hash_set);
+    free(sched->splits);
+    free(sched->hv_tensor_backend_ids);
+    free(sched->hv_tensor_copies);
+    free(sched->node_backend_ids);
+    free(sched->leaf_backend_ids);
+    free(sched->prev_node_backend_ids);
+    free(sched->prev_leaf_backend_ids);
+    free(sched->context_buffer);
+    free(sched->graph.nodes);
+    free(sched->graph.leafs);
+    free(sched);
+}
+
+void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
+    // reset state for the next run
+    if (!sched->is_reset) {
+        ggml_hash_set_reset(&sched->hash_set);
+        memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
+        memset(sched->hv_tensor_copies,       0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
+        sched->is_reset = true;
+    }
+    sched->is_alloc = false;
+}
+
+bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
+    GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
+
+    ggml_backend_sched_split_graph(sched, measure_graph);
+
+    ggml_backend_sched_synchronize(sched);
+
+    if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
+        return false;
+    }
+
+    ggml_backend_sched_reset(sched);
+
+    return true;
+}
+
+bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
+
+    ggml_backend_sched_split_graph(sched, graph);
+
+    if (!ggml_backend_sched_alloc_splits(sched)) {
+        return false;
+    }
+
+    sched->is_alloc = true;
+
+    return true;
+}
+
+enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
+    ggml_backend_sched_synchronize(sched);
+    return err;
+}
+
+enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+    if (!sched->is_reset && !sched->is_alloc) {
+        ggml_backend_sched_reset(sched);
+    }
+
+    if (!sched->is_alloc) {
+        if (!ggml_backend_sched_alloc_graph(sched, graph)) {
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+    }
+
+    return ggml_backend_sched_compute_splits(sched);
+}
+
+void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
+    for (int i = 0; i < sched->n_backends; i++) {
+        ggml_backend_synchronize(sched->backends[i]);
+    }
+}
+
+void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
+    sched->callback_eval = callback;
+    sched->callback_eval_user_data = user_data;
+}
+
+int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+    return sched->n_splits;
+}
+
+int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
+    return sched->n_copies;
+}
+
+int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
+    return sched->n_backends;
+}
+
+ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
+    GGML_ASSERT(i >= 0 && i < sched->n_backends);
+    return sched->backends[i];
+}
+
+size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
+    int backend_index = ggml_backend_sched_backend_id(sched, backend);
+    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+
+    return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
+}
+
+void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
+    int backend_index = ggml_backend_sched_backend_id(sched, backend);
+    GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
+    tensor_backend_id(node) = backend_index;
+    SET_CAUSE(node, "usr");
+    sched->is_reset = false;
+}
+
+ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
+    int backend_index = tensor_backend_id(node);
+    if (backend_index == -1) {
+        return NULL;
+    }
+    return sched->backends[backend_index];
+}
+
+// utils
+
+void ggml_backend_view_init(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->view_src != NULL);
+    GGML_ASSERT(tensor->view_src->buffer != NULL);
+    GGML_ASSERT(tensor->view_src->data != NULL);
+
+    tensor->buffer = tensor->view_src->buffer;
+    tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
+    ggml_backend_buffer_init_tensor(tensor->buffer, tensor);
+}
+
+void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
+    GGML_ASSERT(tensor->buffer == NULL);
+    GGML_ASSERT(tensor->data == NULL);
+    GGML_ASSERT(tensor->view_src == NULL);
+    GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer));
+    GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <=
+                (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer));
+
+    tensor->buffer = buffer;
+    tensor->data = addr;
+    ggml_backend_buffer_init_tensor(buffer, tensor);
+}
+
+static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies,
+    struct ggml_context * ctx_allocated, struct ggml_context * ctx_unallocated, struct ggml_tensor * src) {
+
+    GGML_ASSERT(src != NULL);
+    GGML_ASSERT(src->data && "graph must be allocated");
+
+    size_t id = ggml_hash_insert(&hash_set, src);
+    if (id == GGML_HASHSET_ALREADY_EXISTS) {
+        return node_copies[ggml_hash_find(&hash_set, src)];
+    }
+
+    struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src);
+    if (src->view_src != NULL) {
+        dst->view_src = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, src->view_src);
+        dst->view_offs = src->view_offs;
+    }
+    dst->op = src->op;
+    memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
+    ggml_set_name(dst, src->name);
+
+    // copy src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            continue;
+        }
+        dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
+    }
+
+    node_copies[id] = dst;
+    return dst;
+}
+
+static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) {
+    size_t id = ggml_hash_find(hash_set, src);
+    if (node_init[id]) {
+        return;
+    }
+    node_init[id] = true;
+
+    struct ggml_tensor * dst = node_copies[id];
+    if (dst->view_src != NULL) {
+        graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src);
+        ggml_backend_view_init(dst);
+    }
+    else {
+        ggml_backend_tensor_copy(src, dst);
+    }
+
+    // init src
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        struct ggml_tensor * s = src->src[i];
+        if (s == NULL) {
+            continue;
+        }
+        graph_copy_init_tensor(hash_set, node_copies, node_init, s);
+    }
+}
+
+struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
+    struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
+    struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
+    bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false),
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ true
+    };
+
+    struct ggml_context * ctx_allocated = ggml_init(params);
+    struct ggml_context * ctx_unallocated = ggml_init(params);
+
+    if (ctx_allocated == NULL || ctx_unallocated == NULL) {
+        GGML_LOG_ERROR("%s: failed to allocate context for graph copy\n", __func__);
+        ggml_hash_set_free(&hash_set);
+        free(node_copies);
+        free(node_init);
+        ggml_free(ctx_allocated);
+        ggml_free(ctx_unallocated);
+        return {
+            /* .buffer           = */ NULL,
+            /* .ctx_allocated    = */ NULL,
+            /* .ctx_unallocated  = */ NULL,
+            /* .graph            = */ NULL,
+        };
+    }
+
+    // dup nodes
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, node);
+    }
+
+    // allocate nodes
+    ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
+    if (buffer == NULL) {
+        GGML_LOG_ERROR("%s: failed to allocate buffer for graph copy\n", __func__);
+        ggml_hash_set_free(&hash_set);
+        free(node_copies);
+        free(node_init);
+        ggml_free(ctx_allocated);
+        ggml_free(ctx_unallocated);
+        return {
+            /* .buffer           = */ NULL,
+            /* .ctx_allocated    = */ NULL,
+            /* .ctx_unallocated  = */ NULL,
+            /* .graph            = */ NULL,
+        };
+    }
+
+    //printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
+
+    // copy data and init views
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        graph_copy_init_tensor(&hash_set, node_copies, node_init, node);
+    }
+
+    // build graph copy
+    struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false);
+    for (int i = 0; i < graph->n_nodes; i++) {
+        struct ggml_tensor * node = graph->nodes[i];
+        struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)];
+        graph_copy->nodes[i] = node_copy;
+    }
+    graph_copy->n_nodes = graph->n_nodes;
+
+    ggml_hash_set_free(&hash_set);
+    free(node_copies);
+    free(node_init);
+
+    return {
+        /* .buffer           = */ buffer,
+        /* .ctx_allocated    = */ ctx_allocated,
+        /* .ctx_unallocated  = */ ctx_unallocated,
+        /* .graph            = */ graph_copy,
+    };
+}
+
+void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
+    ggml_backend_buffer_free(copy.buffer);
+    ggml_free(copy.ctx_allocated);
+    ggml_free(copy.ctx_unallocated);
+}
+
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+    struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
+    if (copy.buffer == NULL) {
+        return false;
+    }
+
+    struct ggml_cgraph * g1 = graph;
+    struct ggml_cgraph * g2 = copy.graph;
+
+    assert(g1->n_nodes == g2->n_nodes);
+
+    for (int i = 0; i < g1->n_nodes; i++) {
+        //printf("eval %d/%d\n", i, g1->n_nodes);
+        struct ggml_tensor * t1 = g1->nodes[i];
+        struct ggml_tensor * t2 = g2->nodes[i];
+
+        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);
+
+        if (ggml_is_view_op(t1->op)) {
+            continue;
+        }
+
+        // compare results, calculate rms etc
+        if (!callback(i, t1, t2, user_data)) {
+            break;
+        }
+    }
+
+    ggml_backend_graph_copy_free(copy);
+
+    return true;
+}
+
+// CPU backend - buffer
+
+static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
+    uintptr_t data = (uintptr_t)buffer->context;
+
+    // align the buffer
+    if (data % TENSOR_ALIGNMENT != 0) {
+        data = GGML_PAD(data, TENSOR_ALIGNMENT);
+    }
+
+    return (void *)data;
+}
+
+static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_aligned_free(buffer->context, buffer->size);
+}
+
+static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    memcpy((char *)tensor->data + offset, data, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        memcpy(dst->data, src->data, ggml_nbytes(src));
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    memset(buffer->context, value, buffer->size);
+}
+
+static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = {
+    /* .free_buffer     = */ ggml_backend_cpu_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cpu_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = {
+    /* .free_buffer     = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
+    /* .get_base        = */ ggml_backend_cpu_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_cpu_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_cpu_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cpu_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cpu_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cpu_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// CPU backend buffer type
+
+// this buffer type is defined here to make it available to all backends
+
+static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * data = ggml_aligned_malloc(size);
+
+    if (data == NULL) {
+        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size);
+        return NULL;
+    }
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, size);
+}
+
+static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return true;
+
+    GGML_UNUSED(buft);
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
+        /* .iface   = */ {
+            /* .get_name         = */ ggml_backend_cpu_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type;
+}
+
+static const char * ggml_backend_cpu_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_Mapped";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
+        /* .iface   = */ {
+            /* .get_name         = */ ggml_backend_cpu_buffer_from_ptr_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .device  = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type;
+}
+
+ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
+    GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
+    return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
+}
diff --git a/llama/ggml-backend.h b/llama/ggml-backend.h
new file mode 100644
index 000000000..b67a183f5
--- /dev/null
+++ b/llama/ggml-backend.h
@@ -0,0 +1,378 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+
+#ifdef GGML_BACKEND_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BACKEND_BUILD
+#            define GGML_BACKEND_API __declspec(dllexport) extern
+#        else
+#            define GGML_BACKEND_API __declspec(dllimport) extern
+#        endif
+#    else
+#        define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
+#    endif
+#else
+#    define GGML_BACKEND_API extern
+#endif
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
+    typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
+    typedef struct ggml_backend_event * ggml_backend_event_t;
+    typedef struct ggml_backend * ggml_backend_t;
+    typedef void * ggml_backend_graph_plan_t;
+    typedef struct ggml_backend_reg * ggml_backend_reg_t;
+    typedef struct ggml_backend_device * ggml_backend_dev_t;
+
+
+    //
+    // Backend buffer type
+    //
+
+    GGML_API const char *          ggml_backend_buft_name          (ggml_backend_buffer_type_t buft);
+    GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer  (ggml_backend_buffer_type_t buft, size_t size);
+    GGML_API size_t                ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
+    GGML_API size_t                ggml_backend_buft_get_max_size  (ggml_backend_buffer_type_t buft);
+    GGML_API size_t                ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
+    GGML_API bool                  ggml_backend_buft_is_host       (ggml_backend_buffer_type_t buft);
+    GGML_API ggml_backend_dev_t    ggml_backend_buft_get_device    (ggml_backend_buffer_type_t buft);
+
+    //
+    // Backend buffer
+    //
+
+    enum ggml_backend_buffer_usage {
+        GGML_BACKEND_BUFFER_USAGE_ANY = 0,
+        GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
+        GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2,
+    };
+
+    GGML_API const char *                   ggml_backend_buffer_name          (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_free          (ggml_backend_buffer_t buffer);
+    GGML_API void *                         ggml_backend_buffer_get_base      (ggml_backend_buffer_t buffer);
+    GGML_API size_t                         ggml_backend_buffer_get_size      (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_init_tensor   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API size_t                         ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
+    GGML_API size_t                         ggml_backend_buffer_get_max_size  (ggml_backend_buffer_t buffer);
+    GGML_API size_t                         ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
+    GGML_API void                           ggml_backend_buffer_clear         (ggml_backend_buffer_t buffer, uint8_t value);
+    GGML_API bool                           ggml_backend_buffer_is_host       (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_set_usage     (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
+    GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage     (ggml_backend_buffer_t buffer);
+    GGML_API ggml_backend_buffer_type_t     ggml_backend_buffer_get_type      (ggml_backend_buffer_t buffer);
+    GGML_API void                           ggml_backend_buffer_reset         (ggml_backend_buffer_t buffer);
+
+    // tensor copy between different backends
+    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    //
+    // Backend (stream)
+    //
+
+    GGML_API ggml_guid_t  ggml_backend_guid(ggml_backend_t backend);
+    GGML_API const char * ggml_backend_name(ggml_backend_t backend);
+    GGML_API void         ggml_backend_free(ggml_backend_t backend);
+
+    GGML_API ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend);
+    GGML_API ggml_backend_buffer_t      ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
+    GGML_API size_t                     ggml_backend_get_alignment(ggml_backend_t backend);
+    GGML_API size_t                     ggml_backend_get_max_size(ggml_backend_t backend);
+
+    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+
+    // "offset" refers to the offset in tensor->data for setting/getting data
+    GGML_API void ggml_backend_tensor_set(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_memset(   struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
+
+    GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
+
+    GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API void                      ggml_backend_graph_plan_free  (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+
+    GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
+    GGML_API enum ggml_status ggml_backend_graph_compute      (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+
+    // NOTE: will be removed, use device version instead
+    GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op);
+    GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
+    GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op);
+
+    // asynchronous copy
+    // the copy is performed after all the currently queued operations in backend_src
+    // backend_dst will wait for the copy to complete before performing other operations
+    // automatic fallback to sync copy if async is not supported
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+
+    GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
+
+    //
+    // Events
+    //
+
+    GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device);
+    GGML_API void                 ggml_backend_event_free(ggml_backend_event_t event);
+    GGML_API void                 ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend);
+    GGML_API void                 ggml_backend_event_synchronize(ggml_backend_event_t event);
+    GGML_API void                 ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event);
+
+    //
+    // Backend device
+    //
+
+    enum ggml_backend_dev_type {
+        // CPU device using system memory
+        GGML_BACKEND_DEVICE_TYPE_CPU,
+        // GPU device using dedicated memory
+        GGML_BACKEND_DEVICE_TYPE_GPU,
+        // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
+        GGML_BACKEND_DEVICE_TYPE_ACCEL
+    };
+
+    // functionality supported by the device
+    struct ggml_backend_dev_caps {
+        // asynchronous operations
+        bool async;
+        // pinned host buffer
+        bool host_buffer;
+        // creating buffers from host ptr
+        bool buffer_from_host_ptr;
+        // event synchronization
+        bool events;
+    };
+
+    // all the device properties
+    struct ggml_backend_dev_props {
+        const char * name;
+        const char * description;
+        size_t memory_free;
+        size_t memory_total;
+        enum ggml_backend_dev_type type;
+        struct ggml_backend_dev_caps caps;
+    };
+
+    GGML_API const char *                  ggml_backend_dev_name(ggml_backend_dev_t device);
+    GGML_API const char *                  ggml_backend_dev_description(ggml_backend_dev_t device);
+    GGML_API void                          ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total);
+    GGML_API enum ggml_backend_dev_type    ggml_backend_dev_type(ggml_backend_dev_t device);
+    GGML_API void                          ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
+    GGML_API ggml_backend_reg_t            ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
+    GGML_API ggml_backend_t                ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
+    GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
+    GGML_API ggml_backend_buffer_type_t    ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
+    GGML_API ggml_backend_buffer_t         ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
+
+    GGML_API bool                          ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
+    GGML_API bool                          ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft);
+    GGML_API bool                          ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op);
+
+    //
+    // Backend (reg)
+    //
+
+    GGML_API const char *       ggml_backend_reg_name(ggml_backend_reg_t reg);
+    GGML_API size_t             ggml_backend_reg_dev_count(ggml_backend_reg_t reg);
+    GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index);
+    GGML_API void *             ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name);
+
+    // Common functions that may be obtained using ggml_backend_reg_get_proc_address
+
+    // Split buffer type for tensor parallelism
+    typedef ggml_backend_buffer_type_t   (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
+    // Set the number of threads for the backend
+    typedef void                         (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
+    // Get additional buffer types provided by the device (returns a NULL-terminated array)
+    typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device);
+    // Set the abort callback for the backend
+    typedef void                         (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data);
+    // Get a list of feature flags supported by the backend (returns a NULL-terminated array)
+    struct ggml_backend_feature {
+        const char * name;
+        const char * value;
+    };
+    typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);
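+
+    // Sketch (for illustration only; the proc name string and the thread count are assumptions):
+    // the typedefs above are meant to be used by casting the pointer returned from
+    // ggml_backend_reg_get_proc_address, e.g.
+    //
+    //     ggml_backend_set_n_threads_t set_n_threads_fn = (ggml_backend_set_n_threads_t)
+    //         ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+    //     if (set_n_threads_fn != NULL) {
+    //         set_n_threads_fn(backend, 8); // skip silently if the backend does not expose it
+    //     }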
+
+    //
+    // Backend registry
+    //
+
+    // Backend (reg) enumeration
+    GGML_API size_t             ggml_backend_reg_count(void);
+    GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
+    GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name);
+
+    // Device enumeration
+    GGML_API size_t             ggml_backend_dev_count(void);
+    GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index);
+    GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name);
+    GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type);
+
+    // Direct backend (stream) initialization
+    // = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params)
+    GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params);
+    // = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params)
+    GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params);
+    // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL)
+    GGML_API ggml_backend_t ggml_backend_init_best(void);
+
+    // Load a backend from a dynamic library and register it
+    GGML_API ggml_backend_reg_t ggml_backend_load(const char * path);
+    // Unload a backend if loaded dynamically and unregister it
+    GGML_API void               ggml_backend_unload(ggml_backend_reg_t reg);
+    // Load all known backends from dynamic libraries
+    GGML_API void               ggml_backend_load_all(void);
+    GGML_API void               ggml_backend_load_all_from_path(const char * dir_path);
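+
+    // Sketch (for illustration only, assuming a build with dynamically loadable backends):
+    // a minimal startup sequence that loads all available backends and picks the best device.
+    //
+    //     ggml_backend_load_all();
+    //     ggml_backend_t backend = ggml_backend_init_best();
+    //     GGML_ASSERT(backend != NULL);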
+
+    //
+    // Backend scheduler
+    //
+
+    // The backend scheduler allows for multiple backend devices to be used together
+    // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
+    // The backends are selected based on:
+    // - the backend that supports the operation
+    // - the location of the pre-allocated tensors (e.g. the weights)
+    /*
+      Example usage:
+
+        // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned
+        // preferably to run on the same backend as the buffer
+        ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+
+        // initialize buffers from a max size graph (optional)
+        reserve_graph = build_graph(sched, max_batch_size);
+
+        // manually assign nodes to a backend (optional, should not be needed in most cases)
+        struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
+        ggml_backend_sched_set_tensor_backend(sched, node, backend_gpu);
+
+        ggml_backend_sched_reserve(sched, reserve_graph);
+
+        // compute
+        graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
+        for (int i = 0; i < 10; ++i) {
+            ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
+        }
+
+        // if there are graph inputs:
+        graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
+        ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+        ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
+        ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
+        ggml_backend_sched_graph_compute(sched, graph); // execute the graph
+
+        // as an alternative to the above it is also possible to assign the inputs to a dedicated context and
+        // allocate them statically via ggml_backend_alloc_ctx_tensors
+    */
+
+    typedef struct ggml_backend_sched * ggml_backend_sched_t;
+
+    // Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback)
+    // when ask == true, the scheduler wants to know if the user wants to observe this node
+    // this allows the scheduler to batch nodes together in order to evaluate them in a single call
+    //
+    // when ask == false, the scheduler is passing the node tensor to the user for observation
+    // if the user returns false, the scheduler will cancel the graph compute
+    //
+    typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
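+
+    // Sketch (for illustration only; the node name "logits" is an arbitrary assumption):
+    // a callback matching the contract described above.
+    //
+    //     static bool observe_logits(struct ggml_tensor * t, bool ask, void * user_data) {
+    //         if (ask) {
+    //             return strcmp(t->name, "logits") == 0; // only break batching for this node
+    //         }
+    //         // here t holds computed data for an observed node; returning false cancels the compute
+    //         return true;
+    //     }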
+
+    // Initialize a backend scheduler, backends with low index are given priority over backends with high index
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+    GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
+
+    // Initialize backend buffers from a measure graph
+    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
+
+    GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
+
+    // Get the number of splits of the last graph
+    GGML_API int                  ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
+    GGML_API int                  ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched);
+
+    GGML_API size_t               ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
+
+    GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
+    GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
+
+    // Allocate and compute graph on the backend scheduler
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
+
+    // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
+    // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
+    // The correct way to use this API is to discard the deallocated tensors and create new ones.
+    GGML_API void                 ggml_backend_sched_reset(ggml_backend_sched_t sched);
+
+    // Set a callback to be called for each resulting node during graph compute
+    GGML_API void                 ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data);
+
+    //
+    // Utils
+    //
+
+    struct ggml_backend_graph_copy {
+        ggml_backend_buffer_t buffer;
+        struct ggml_context * ctx_allocated;
+        struct ggml_context * ctx_unallocated;
+        struct ggml_cgraph * graph;
+    };
+
+    // Copy a graph to a different backend
+    GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph);
+    GGML_API void                           ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy);
+
+    typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
+
+    // Compare the output of two backends
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+
+    // Tensor initialization
+    GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
+    GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
+
+    // CPU buffer types are always available
+    GGML_API ggml_backend_buffer_t      ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/llama/ggml-blas.cpp b/llama/ggml-blas.cpp
new file mode 100644
index 000000000..44acf0bd4
--- /dev/null
+++ b/llama/ggml-blas.cpp
@@ -0,0 +1,547 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifdef GGML_USE_BLAS
+
+#include "ggml-impl.h"
+#include "ggml-blas.h"
+#include "ggml-backend-impl.h"
+
+#include <future>
+#include <vector>
+#include <cstring>
+
+#if defined(GGML_BLAS_USE_ACCELERATE)
+#   include <Accelerate/Accelerate.h>
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#   include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#   include <nvpl_blas.h>
+#else
+#   include <cblas.h>
+#endif
+
+struct ggml_backend_blas_context {
+    int n_threads = GGML_DEFAULT_N_THREADS;
+    std::unique_ptr<char []> work_data;
+    size_t work_size = 0;
+#ifndef GGML_USE_OPENMP
+    std::vector<std::future<void>> tasks;
+#endif
+};
+
+static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const int64_t ne_plane      = ne01*ne00;
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+
+    if (ctx->work_size < desired_wsize) {
+        ctx->work_data.reset(new char[desired_wsize]);
+        ctx->work_size = desired_wsize;
+    }
+    void * wdata = ctx->work_data.get();
+
+    // convert src0 to float
+    if (type != GGML_TYPE_F32) {
+        const auto * type_traits = ggml_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits->to_float;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void  *       x      = (char *)  src0->data + i02*nb02          + i03*nb03;
+                      float * const wplane = (float *) wdata      + i02*ne_plane      + i03*ne02*ne_plane;
+
+                const int min_cols_per_thread = 4096;
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
+
+#ifdef GGML_USE_OPENMP
+                #pragma omp parallel for num_threads(n_threads)
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                }
+#else
+                for (int i = 1; i < n_threads; i++) {
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
+                    if (start < end) {
+                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                            }
+                        }));
+                    }
+                }
+                {
+                    // reuse the current thread for the first task
+                    const int64_t start = 0;
+                    const int64_t end   = ne01/n_threads;
+                    for (int64_t i01 = start; i01 < end; i01++) {
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                    }
+                }
+#endif
+            }
+        }
+
+#ifndef GGML_USE_OPENMP
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
+            task.get();
+        }
+        ctx->tasks.clear();
+#endif
+    }
+
+#if defined(OPENBLAS_VERSION)
+    openblas_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_BLIS)
+    bli_thread_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
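+    // for each 2D plane, the sgemm below computes d = y * x^T in row-major terms:
+    // y is the (ne1 x ne10) src1 plane, x is the (ne01 x ne00) src0 plane (or its
+    // f32 copy in wdata), and d is the (ne1 x ne01) dst plane, i.e. dst = src1 * src0^T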
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const int64_t i03 = i13/r3;
+            const int64_t i02 = i12/r2;
+
+            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+                  float * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
+
+            if (type != GGML_TYPE_F32) {
+                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+            }
+
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne1, ne01, ne10,
+                        1.0f,   y, ne10,
+                                x, ne00,
+                        0.0f,   d, ne01);
+        }
+    }
+}
+
+static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
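+    //
+    // Worked example (illustrative dimensions only): src0 = (k=64, n=4096) and
+    // src1 = (k=64, m=8) give dst = (m=8, n=4096); sgemm then sees the stored
+    // (64 x 8) src1 as the (8 x 64) operand a via CblasTrans (lda = m = 8),
+    // b = src0 as (64 x 4096), and c = dst as (8 x 4096).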
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    CBLAS_TRANSPOSE transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
+// backend interface
+
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
+    return "BLAS";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_blas_free(ggml_backend_t backend) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_backend_blas_mul_mat(ctx, node);
+                break;
+
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
+
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i blas_backend_i = {
+    /* .get_name                = */ ggml_backend_blas_get_name,
+    /* .free                    = */ ggml_backend_blas_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_blas_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_blas_guid(void) {
+    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_blas_init(void) {
+    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_blas_guid(),
+        /* .interface = */ blas_backend_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
+        GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+    }
+#endif
+
+#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+    GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#endif
+
+    return backend;
+}
+
+bool ggml_backend_is_blas(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
+}
+
+void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_blas(backend_blas));
+
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
+    ctx->n_threads = n_threads;
+}
+
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+    #if defined(GGML_BLAS_USE_ACCELERATE)
+        return "Accelerate";
+    #elif defined(GGML_BLAS_USE_MKL)
+        return "MKL";
+    #elif defined(GGML_BLAS_USE_BLIS)
+        return "BLIS";
+    #elif defined(GGML_BLAS_USE_NVPL)
+        return "NVPL";
+    #elif defined(OPENBLAS_VERSION)
+        return "OpenBLAS";
+    #else
+        return "BLAS";
+    #endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return ggml_is_contiguous(src0) &&
+                   ggml_is_contiguous(src1) &&
+                   src1->type == GGML_TYPE_F32 &&
+                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+        }
+
+        case GGML_OP_OUT_PROD:
+            return op->src[0]->type == GGML_TYPE_F32 &&
+                   op->src[1]->type == GGML_TYPE_F32 &&
+                   ggml_is_matrix(src0) &&
+                   ggml_is_matrix(src1) &&
+                   ggml_is_contiguous(src0) &&
+                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init_backend,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_blas_reg_i,
+        /* .context     = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_blas_reg)
+
+#endif // GGML_USE_BLAS
\ No newline at end of file
diff --git a/llama/ggml-blas.h b/llama/ggml-blas.h
new file mode 100644
index 000000000..f5fb9de21
--- /dev/null
+++ b/llama/ggml-blas.h
@@ -0,0 +1,51 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
+
+GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
+
+// number of threads used for conversion to float
+// for openblas and blis, this will also set the number of threads used for blas operations
+GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
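+
+// example usage (illustrative sketch):
+//
+//     ggml_backend_t backend = ggml_backend_blas_init();
+//     ggml_backend_blas_set_n_threads(backend, 8);
+//     // ... ggml_backend_graph_compute(backend, graph); ...
+//     ggml_backend_free(backend);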
+
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/llama/ggml-common.h b/llama/ggml-common.h
new file mode 100644
index 000000000..e227c13fb
--- /dev/null
+++ b/llama/ggml-common.h
@@ -0,0 +1,1879 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef GGML_COMMON_DECL
+
+#if defined(GGML_COMMON_DECL_C)
+#include <stdint.h>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CPP)
+#include <cstdint>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+// standard C++ allows anonymous unions, but some compilers warn on them
+#define GGML_COMMON_AGGR_U data
+// standard C++ does not allow anonymous structs
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_METAL)
+#include <metal_stdlib>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CUDA)
+#if defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h>
+#else
+#include <cuda_fp16.h>
+#endif
+#include <cstdint>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_HIP)
+#include <hip/hip_fp16.h>
+#include <cstdint>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_SYCL)
+#include <sycl/half_type.hpp>
+#include <cstdint>
+
+typedef sycl::half  ggml_half;
+typedef sycl::half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#endif
+
+#if defined(GGML_COMMON_DECL)
+
+#ifndef __cplusplus
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+#endif // __cplusplus
+
+// QK = number of values after dequantization
+// QK_K = super-block size
+
+#define QK_K 256
+#define K_SCALE_SIZE 12
+
+#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
+// QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
+
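+// e.g. for Q4_0 (worked example): QK4_0 = 32 values per block and QR4_0 = 2
+// (each stored byte expands to two 4-bit values), so QI4_0 = 32/(4*2) = 4,
+// matching the 16 bytes of quants stored in block_q4_0
+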
+#define QI4_0 (QK4_0 / (4 * QR4_0))
+#define QR4_0 2
+
+#define QI4_1 (QK4_1 / (4 * QR4_1))
+#define QR4_1 2
+
+#define QI5_0 (QK5_0 / (4 * QR5_0))
+#define QR5_0 2
+
+#define QI5_1 (QK5_1 / (4 * QR5_1))
+#define QR5_1 2
+
+#define QI8_0 (QK8_0 / (4 * QR8_0))
+#define QR8_0 1
+
+#define QI8_1 (QK8_1 / (4 * QR8_1))
+#define QR8_1 1
+
+#define QI2_K (QK_K / (4*QR2_K))
+#define QR2_K 4
+
+#define QI3_K (QK_K / (4*QR3_K))
+#define QR3_K 4
+
+#define QI4_K (QK_K / (4*QR4_K))
+#define QR4_K 2
+
+#define QI5_K (QK_K / (4*QR5_K))
+#define QR5_K 2
+
+#define QI6_K (QK_K / (4*QR6_K))
+#define QR6_K 2
+
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+#define QR2_XXS 4
+
+#define QI2_XS (QK_K / (4*QR2_XS))
+#define QR2_XS 4
+
+#define QI2_S (QK_K / (4*QR2_S))
+#define QR2_S 4
+
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+#define QR3_XXS 4
+
+#define QI3_XS (QK_K / (4*QR3_XS))
+#define QR3_XS 4
+
+#define QI1_S (QK_K / (4*QR1_S))
+#define QR1_S 8
+
+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8
+
+#define QI4_NL (QK4_NL / (4*QR4_NL))
+#define QR4_NL 2
+
+#define QI4_XS (QK_K / (4*QR4_XS))
+#define QR4_XS 2
+
+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 4
+
+#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP || GGML_COMMON_DECL_SYCL
+
+#define QK4_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half m; // min
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half m; // min
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    ggml_half d;       // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half s; // d * sum(qs[i])
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 ds;
+    } GGML_COMMON_AGGR_U;
+    int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
+
+//
+// Ternary quantization
+//
+
+// 1.6875 bpw
+typedef struct {
+    uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256)
+    uint8_t qh[QK_K/64]; // 4 elements per byte
+    ggml_half d;
+} block_tq1_0;
+static_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding");
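+// worked size check: qs = (256 - 16)/5 = 48 bytes, qh = 256/64 = 4 bytes and d = 2 bytes,
+// i.e. 54 bytes per 256 weights = 54*8/256 = 1.6875 bits per weight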
+
+// 2.0625 bpw
+typedef struct {
+    uint8_t qs[QK_K/4]; // 2 bits per element
+    ggml_half d;
+} block_tq2_0;
+static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");
+
+//
+// Super-block quantization structures
+//
+
+// 2-bit quantization
+// weight is represented as x = a * q + b
+// 16 blocks of 16 elements each
+// Effectively 2.625 bits per weight
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
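+// worked size check: scales = 256/16 = 16 bytes, qs = 256/4 = 64 bytes and d + dmin = 4 bytes,
+// i.e. 84 bytes per 256 weights = 84*8/256 = 2.625 bits per weight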
+
+// 3-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 3.4375 bits per weight
+typedef struct {
+    uint8_t hmask[QK_K/8]; // quants - high bit
+    uint8_t qs[QK_K/4];    // quants - low 2 bits
+    uint8_t scales[12];    // scales, quantized with 6 bits
+    ggml_half d;           // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+
+// 4-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 4.5 bits per weight
+typedef struct {
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];           // 4-bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+
+// 5-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 5.5 bits per weight
+typedef struct {
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];           // quants, high bit
+    uint8_t qs[QK_K/2];           // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    ggml_half d;             // super-block scale
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
+
+// This is only used for intermediate quantization and dot products
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK_K];       // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
+// (Almost) "true" 2-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 2.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
+// 2.3125 bpw quants
+typedef struct {
+    ggml_half d;
+    uint16_t qs[QK_K/8];
+    uint8_t  scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
+// 2.5625 bpw quants
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t scales[QK_K/32];
+} block_iq2_s;
+static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
+
+// (Almost) "true" 3-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 3.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_half d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
+// 3.4375 bpw
+#define IQ3S_N_SCALE QK_K/64
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
+// 1.5625 bpw
+typedef struct {
+    ggml_half d;
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
+// 1.75 bpw
+typedef struct {
+    uint8_t  qs[QK_K/8];      // grid index, low 8 bits
+    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
+    uint8_t  scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+    ggml_half f16;
+    uint16_t  u16;
+} iq1m_scale_t;
+
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
+#endif // GGML_COMMON_DECL
+#endif // GGML_COMMON_DECL
+
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef GGML_COMMON_IMPL
+
+#if defined(GGML_COMMON_IMPL_C)
+#include <stdint.h>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CPP)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_METAL)
+#include <metal_stdlib>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_SYCL)
+
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#endif
+
+#if defined(GGML_COMMON_IMPL)
+
+GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8)
+    1, 2, 4, 8, 16, 32, 64, 128
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+GGML_TABLE_END()
+
+//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
+GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
+    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+GGML_TABLE_END()
+//#endif
+
+
+GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
+    0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
+    0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
+    0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
+    0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
+    0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
+    0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
+    0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
+    0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
+    0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
+    0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
+    0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
+    0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
+    0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
+    0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
+    0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
+    0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
+    0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
+    0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
+    0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
+    0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
+    0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
+    0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
+    0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
+    0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
+    0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
+    0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
+    0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
+    0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
+    0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
+    0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
+    0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
+    0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
+    0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
+    0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
+    0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
+    0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
+    0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
+    0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
+    0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
+    0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
+    0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
+    0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
+    0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
+    0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
+    0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
+    0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
+    0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
+    0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
+    0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
+    0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
+    0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
+    0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
+    0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
+    0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
+    0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
+    0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
+    0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
+    0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
+    0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
+    0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
+    0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
+    0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
+    0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
+    0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
+    0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
+    0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
+    0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
+    0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
+    0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
+    0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
+    0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
+    0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
+    0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
+    0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
+    0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
+    0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
+    0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
+    0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
+    0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
+    0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
+    0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
+    0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
+    0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
+    0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
+    0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
+    0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
+    0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
+    0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
+    0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
+    0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
+    0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
+    0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
+    0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
+    0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
+    0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
+    0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
+    0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
+    0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
+    0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
+    0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
+    0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
+    0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
+    0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
+    0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
+    0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
+    0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
+    0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
+    0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
+    0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
+    0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
+    0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
+    0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
+    0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
+    0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
+    0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
+    0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
+    0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
+    0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
+    0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
+    0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
+    0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
+    0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
+    0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
+    0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
+    0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
+    0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
+    0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
+    0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
+    0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
+    0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
+    0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
+    0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
+    0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
+    0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
+    0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
+    0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
+    0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
+    0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
+    0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
+    0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
+    0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
+    0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
+    0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
+    0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
+    0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
+    0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
+    0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
+    0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
+    0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
+    0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
+    0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
+    0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
+    0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
+    0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
+    0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
+    0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
+    0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
+    0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
+    0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
+    0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
+    0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
+    0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
+    0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
+    0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
+    0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
+    0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
+    0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
+    0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
+    0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
+    0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
+    0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
+    0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
+    0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
+    0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
+    0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
+    0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
+    0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
+    0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
+    0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
+    0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
+    0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
+    0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
+    0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
+    0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
+    0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
+    0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
+    0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
+    0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
+    0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
+    0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
+    0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
+    0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
+    0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
+    0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
+    0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
+    0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
+    0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
+    0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
+    0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
+    0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
+    0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
+    0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
+    0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
+    0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
+    0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
+    0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
+    0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
+    0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
+    0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
+    0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
+    0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
+    0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
+    0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
+    0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
+    0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
+    0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
+    0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
+    0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
+    0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
+    0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
+    0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
+    0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
+    0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
+    0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
+    0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
+    0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
+    0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
+    0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
+    0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
+    0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
+    0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
+    0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
+    0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
+    0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
+    0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
+    0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
+    0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
+    0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
+    0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
+    0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
+    0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
+    0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256)
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
+    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
+    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
+    0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
+    0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
+    0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
+    0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
+    0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
+    0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
+    0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
+    0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
+    0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
+    0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
+    0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
+    0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
+    0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
+    0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
+    0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
+    0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
+    0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
+    0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
+    0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
+    0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
+    0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
+    0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
+    0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
+    0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
+    0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
+    0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
+    0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
+    0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
+    0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
+    0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
+    0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
+    0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
+    0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
+    0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
+    0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
+    0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
+    0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
+    0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
+    0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
+    0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
+    0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
+    0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
+    0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
+    0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
+    0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
+    0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
+    0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
+    0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
+    0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
+    0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
+    0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
+    0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
+    0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
+    0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
+    0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
+    0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
+    0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
+    0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
+    0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
+    0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
+    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
+    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
+GGML_TABLE_END()
+
+#define NGRID_IQ1S 2048
+#define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
+#if defined(GGML_COMMON_IMPL_C)
+GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
+    0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
+    0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff,
+    0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000,
+    0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000,
+    0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101,
+    0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff,
+    0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00,
+    0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101,
+    0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100,
+    0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00,
+    0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000,
+    0xffffff000100ff00, 0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000,
+    0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101,
+    0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff,
+    0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100,
+    0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01,
+    0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000,
+    0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff,
+    0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001,
+    0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000,
+    0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00,
+    0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff,
+    0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff,
+    0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000,
+    0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff,
+    0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff,
+    0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101,
+    0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff,
+    0xffff000001000101, 0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000,
+    0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100,
+    0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01,
+    0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00,
+    0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00,
+    0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01,
+    0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 0xffff01ffff01ffff,
+    0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000,
+    0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff,
+    0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000,
+    0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101,
+    0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100,
+    0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff,
+    0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff,
+    0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001,
+    0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01,
+    0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff,
+    0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000,
+    0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000,
+    0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff,
+    0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01,
+    0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00,
+    0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000,
+    0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000,
+    0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000,
+    0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001,
+    0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff,
+    0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01,
+    0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff,
+    0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff,
+    0xff00ff000000ff00, 0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000,
+    0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000,
+    0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00,
+    0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001,
+    0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100,
+    0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101,
+    0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00,
+    0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00,
+    0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000,
+    0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001,
+    0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff,
+    0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000,
+    0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000,
+    0xff0000ff000101ff, 0xff0000ff01ff00ff, 0xff0000ff01ff0100, 0xff0000ff0100ffff,
+    0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100,
+    0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000,
+    0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101,
+    0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001,
+    0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000,
+    0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff,
+    0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01,
+    0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100,
+    0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000,
+    0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 0xff00000001ffff01,
+    0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101,
+    0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000,
+    0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff,
+    0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff,
+    0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001,
+    0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001,
+    0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000,
+    0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000,
+    0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00,
+    0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100,
+    0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01,
+    0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00,
+    0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff,
+    0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000,
+    0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00,
+    0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000,
+    0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff,
+    0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00,
+    0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000,
+    0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000,
+    0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff,
+    0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00,
+    0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00,
+    0xff00010001010001, 0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001,
+    0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001,
+    0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000,
+    0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100,
+    0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101,
+    0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101,
+    0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000,
+    0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00,
+    0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff,
+    0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 0xff01ffff01000000,
+    0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff,
+    0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff,
+    0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff,
+    0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101,
+    0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001,
+    0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100,
+    0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101,
+    0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01,
+    0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00,
+    0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00,
+    0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000,
+    0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101,
+    0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100,
+    0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 0xff0100ff00ff0001,
+    0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff,
+    0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00,
+    0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000,
+    0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100,
+    0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00,
+    0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01,
+    0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff,
+    0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101,
+    0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100,
+    0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100,
+    0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000,
+    0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff,
+    0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff,
+    0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000,
+    0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101,
+    0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000,
+    0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101,
+    0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101,
+    0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100,
+    0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01,
+    0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001,
+    0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001,
+    0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff,
+    0xff010101ffffff01, 0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff,
+    0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000,
+    0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000,
+    0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101,
+    0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff,
+    0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001,
+    0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000,
+    0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff,
+    0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00,
+    0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff,
+    0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000,
+    0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00,
+    0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff,
+    0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000,
+    0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000,
+    0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff,
+    0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000,
+    0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000,
+    0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000,
+    0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101,
+    0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff,
+    0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100,
+    0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101,
+    0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 0x00ff00ff00000001,
+    0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff,
+    0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100,
+    0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff,
+    0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01,
+    0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff,
+    0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff,
+    0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff,
+    0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff,
+    0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001,
+    0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff,
+    0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01,
+    0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00,
+    0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100,
+    0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101,
+    0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff,
+    0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00,
+    0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01,
+    0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00,
+    0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100,
+    0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00,
+    0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000,
+    0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000,
+    0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000,
+    0x00ff01ff00ffff01, 0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000,
+    0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001,
+    0x00ff01ff00010100, 0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff,
+    0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00,
+    0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff,
+    0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00,
+    0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000,
+    0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff,
+    0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff,
+    0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101,
+    0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000,
+    0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff,
+    0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff,
+    0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff,
+    0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff,
+    0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000,
+    0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000,
+    0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100,
+    0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00,
+    0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100,
+    0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100,
+    0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00,
+    0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff,
+    0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 0x0000ff00ff0000ff,
+    0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100,
+    0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff,
+    0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000,
+    0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00,
+    0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+    0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00,
+    0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100,
+    0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff,
+    0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff,
+    0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001,
+    0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00,
+    0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101,
+    0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001,
+    0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000,
+    0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100,
+    0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000,
+    0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001,
+    0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000,
+    0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff,
+    0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101,
+    0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+    0x000000ff00000001, 0x000000ff000001ff, 0x000000ff00000100, 0x000000ff00000101,
+    0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000,
+    0x000000ff00010001, 0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff,
+    0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000,
+    0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff,
+    0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01,
+    0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100,
+    0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff,
+    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101,
+    0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001,
+    0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01,
+    0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff,
+    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+    0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff,
+    0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00,
+    0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff,
+    0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff,
+    0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff,
+    0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001,
+    0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff,
+    0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff,
+    0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001,
+    0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff,
+    0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff,
+    0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000,
+    0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 0x0000000100ffff00,
+    0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100,
+    0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01,
+    0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff,
+    0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff,
+    0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000,
+    0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101,
+    0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01,
+    0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100,
+    0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100,
+    0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00,
+    0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101,
+    0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001,
+    0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01,
+    0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100,
+    0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff,
+    0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01,
+    0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff,
+    0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 0x00000100ff010000,
+    0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001,
+    0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01,
+    0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff,
+    0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff,
+    0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00,
+    0x0000010001ff0000, 0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff,
+    0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100,
+    0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000,
+    0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001,
+    0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100,
+    0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff,
+    0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff,
+    0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01,
+    0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101,
+    0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001,
+    0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff,
+    0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101,
+    0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100,
+    0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00,
+    0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01,
+    0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001,
+    0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00,
+    0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000,
+    0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101,
+    0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001,
+    0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100,
+    0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000,
+    0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000,
+    0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 0x000100ffff00ff00,
+    0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101,
+    0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff,
+    0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101,
+    0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001,
+    0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01,
+    0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100,
+    0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000,
+    0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00,
+    0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff,
+    0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001,
+    0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001,
+    0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff,
+    0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00,
+    0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100,
+    0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00,
+    0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101,
+    0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff,
+    0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff,
+    0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff,
+    0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00,
+    0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff,
+    0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff,
+    0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101,
+    0x00010001010000ff, 0x0001000101000001, 0x00010001010001ff, 0x0001000101000100,
+    0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101,
+    0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00,
+    0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00,
+    0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff,
+    0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff,
+    0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001,
+    0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000,
+    0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff,
+    0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff,
+    0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff,
+    0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff,
+    0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001,
+    0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100,
+    0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101,
+    0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001,
+    0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001,
+    0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101,
+    0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101,
+    0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff,
+    0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff,
+    0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000,
+    0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101,
+    0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 0x01ffff00ff000001,
+    0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff,
+    0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000,
+    0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff,
+    0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100,
+    0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000,
+    0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101,
+    0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff,
+    0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100,
+    0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff,
+    0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01,
+    0x01ffff01010101ff, 0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100,
+    0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000,
+    0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100,
+    0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100,
+    0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001,
+    0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000,
+    0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000,
+    0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000,
+    0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00,
+    0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff,
+    0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001,
+    0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000,
+    0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101,
+    0x01ff00010000ffff, 0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101,
+    0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000,
+    0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff,
+    0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000,
+    0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101,
+    0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff,
+    0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff,
+    0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000,
+    0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101,
+    0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff,
+    0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff,
+    0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01,
+    0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff,
+    0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000,
+    0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101,
+    0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff,
+    0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff,
+    0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff,
+    0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01,
+    0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00,
+    0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000,
+    0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000,
+    0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101,
+    0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 0x0100ffff01000001,
+    0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff,
+    0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000,
+    0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff,
+    0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000,
+    0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000,
+    0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000,
+    0x0100ff0001ff00ff, 0x0100ff0001ff0001, 0x0100ff000100ff01, 0x0100ff0001000000,
+    0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001,
+    0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff,
+    0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001,
+    0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000,
+    0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000,
+    0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00,
+    0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000,
+    0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101,
+    0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001,
+    0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00,
+    0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001,
+    0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff,
+    0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000,
+    0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00,
+    0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100,
+    0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101,
+    0x0100000000ffff00, 0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001,
+    0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01,
+    0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff,
+    0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff,
+    0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00,
+    0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01,
+    0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100,
+    0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000,
+    0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff,
+    0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff,
+    0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff,
+    0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000,
+    0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff,
+    0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101,
+    0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff,
+    0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001,
+    0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001,
+    0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001,
+    0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000,
+    0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000,
+    0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff,
+    0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101,
+    0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001,
+    0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 0x0100010000ff01ff,
+    0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000,
+    0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000,
+    0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 0x0100010001ff00ff,
+    0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101,
+    0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000,
+    0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100,
+    0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00,
+    0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000,
+    0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff,
+    0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01,
+    0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00,
+    0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff,
+    0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000,
+    0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101,
+    0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff,
+    0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001,
+    0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff,
+    0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000,
+    0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100,
+    0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff,
+    0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101,
+    0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100,
+    0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff,
+    0x0101ff0101ff0101, 0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01,
+    0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000,
+    0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff,
+    0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00,
+    0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100,
+    0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff,
+    0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001,
+    0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00,
+    0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100,
+    0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff,
+    0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01,
+    0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00,
+    0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000,
+    0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01,
+    0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff,
+    0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000,
+    0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff,
+    0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff,
+    0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff,
+    0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001,
+    0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff,
+    0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01,
+    0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff,
+    0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 0x0101010000ffff00,
+    0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00,
+    0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001,
+    0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101,
+    0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101,
+    0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff,
+    0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000,
+    0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101,
+GGML_TABLE_END()
+#else
+GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)
+    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
+    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
+    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
+    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
+    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
+    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
+    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
+    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
+    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
+    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
+    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
+    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
+    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
+    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
+    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
+    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
+    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
+    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
+    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
+    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
+    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
+    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
+    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
+    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
+    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
+    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
+    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
+    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
+    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
+    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
+    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
+    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
+    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
+    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
+    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
+    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
+    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
+    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
+    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
+    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
+    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
+    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
+    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
+    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
+    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
+    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
+    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
+    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
+    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
+    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
+    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
+    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
+    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
+    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
+    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
+    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
+    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
+    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
+    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
+    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
+    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
+    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
+    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
+    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
+    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
+    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
+    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
+    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
+    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
+    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
+    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
+    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
+    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
+    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
+    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
+    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
+    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
+    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
+    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
+    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
+    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
+    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
+    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
+    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
+    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
+    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
+    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
+    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
+    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
+    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
+    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
+    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
+    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
+    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
+    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
+    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
+    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
+    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
+    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
+    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
+    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
+    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
+    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
+    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
+    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
+    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
+    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
+    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
+    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
+    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
+    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
+    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
+    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
+    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
+    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
+    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
+    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
+    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
+    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
+    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
+    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
+    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
+    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
+    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
+    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
+    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
+    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
+    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
+    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
+    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
+    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
+    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
+    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
+    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
+    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
+    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
+    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
+    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
+    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
+    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
+    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
+    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
+    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
+    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
+    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
+    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
+    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
+    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
+    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
+    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
+    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
+    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
+    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
+    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
+    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
+    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
+    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
+    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
+    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
+    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
+    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
+    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
+    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
+    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
+    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
+    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
+    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
+    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
+    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
+    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
+    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
+    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
+    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
+    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
+    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
+    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
+    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
+    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
+    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
+    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
+    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
+    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
+    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
+    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
+    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
+    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
+    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
+    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
+    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
+    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
+    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
+    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
+    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
+    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
+    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
+    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
+    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
+    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
+    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
+    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
+    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
+    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
+    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
+    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
+    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
+    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
+    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
+    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
+    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
+    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
+    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
+    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
+    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
+    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
+    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
+    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
+    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
+    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
+    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
+    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
+    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
+    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
+    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
+    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
+    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
+    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
+    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
+    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
+    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
+    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
+    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
+    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
+    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
+    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
+    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
+    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
+    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
+    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
+    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
+    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
+    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
+    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
+    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
+    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
+    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
+    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
+    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
+    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
+    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
+    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
+    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
+    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
+    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
+    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
+    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
+    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
+GGML_TABLE_END()
+#endif
+
+#endif // GGML_COMMON_IMPL
+#endif // GGML_COMMON_IMPL
diff --git a/llama/ggml-cpp.h b/llama/ggml-cpp.h
new file mode 100644
index 000000000..ceb54875d
--- /dev/null
+++ b/llama/ggml-cpp.h
@@ -0,0 +1,64 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#ifndef __cplusplus
+#error "This header is for C++ only"
+#endif
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include <memory>
+
+// Smart pointers for ggml types
+
+// ggml
+
+struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } };
+struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } };
+
+typedef std::unique_ptr<ggml_context, ggml_context_deleter> ggml_context_ptr;
+typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;
+
+// ggml-alloc
+
+struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };
+
+typedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;
+
+// ggml-backend
+
+struct ggml_backend_deleter        { void operator()(ggml_backend_t backend)       { ggml_backend_free(backend); } };
+struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } };
+struct ggml_backend_event_deleter  { void operator()(ggml_backend_event_t event)   { ggml_backend_event_free(event); } };
+struct ggml_backend_sched_deleter  { void operator()(ggml_backend_sched_t sched)   { ggml_backend_sched_free(sched); } };
+
+typedef std::unique_ptr<ggml_backend,        ggml_backend_deleter>        ggml_backend_ptr;
+typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;
+typedef std::unique_ptr<ggml_backend_event,  ggml_backend_event_deleter>  ggml_backend_event_ptr;
+typedef std::unique_ptr<ggml_backend_sched,  ggml_backend_sched_deleter>  ggml_backend_sched_ptr;
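Note: the typedefs above wrap raw ggml handles in RAII owners, so a context, backend, buffer, or scheduler is released automatically when the owner goes out of scope. A minimal usage sketch (illustrative only, not part of the vendored header; assumes the usual ggml_init entry point and its three-field ggml_init_params):

    #include "ggml-cpp.h"

    void example() {
        // ggml_init returns a raw ggml_context *; ggml_context_ptr calls ggml_free on scope exit.
        ggml_context_ptr ctx { ggml_init({ /*mem_size   =*/ 16u * 1024 * 1024,
                                           /*mem_buffer =*/ nullptr,
                                           /*no_alloc   =*/ false }) };
        // The backend, buffer, event, and sched typedefs follow the same pattern,
        // invoking their respective ggml_backend_*_free functions.
    }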
diff --git a/llama/ggml-cpu-aarch64.cpp b/llama/ggml-cpu-aarch64.cpp
new file mode 100644
index 000000000..0989fb203
--- /dev/null
+++ b/llama/ggml-cpu-aarch64.cpp
@@ -0,0 +1,4271 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_IMPL_CPP
+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu-traits.h"
+
+#include <cmath>
+#include <cstring>
+#include <cassert>
+#include <cfloat>
+#include <cstdlib> // for qsort
+#include <cstdio>  // for GGML_ASSERT
+
+#include "ggml-cpu-aarch64.h"
+
+// TODO: move to include file?
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
+
+template <int K, int N> struct block {
+    ggml_half d[N];                         // deltas for N qK_0 blocks
+    int8_t    qs[(QK_0<K>() * N * K) / 8];  // quants for N qK_0 blocks
+};
+
+// control size
+static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
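+// e.g. block<4, 4>: 4 ggml_half deltas (8 bytes) + (QK4_0 * 4 blocks * 4 bits) / 8 = 64 bytes
+// of packed quants = 72 bytes total, matching the 4 * sizeof(ggml_half) + QK8_0 * 2 assert above.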
+
+using block_q4_0x4 = block<4, 4>;
+using block_q4_0x8 = block<4, 8>;
+using block_q8_0x4 = block<8, 4>;
+using block_q8_0x8 = block<8, 8>;
+
+struct block_iq4_nlx4 {
+    ggml_half d[4];            // deltas for 4 iq4_nl blocks
+    uint8_t   qs[QK4_NL * 2];  // nibbles / quants for 4 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#elif defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// Functions to create the interleaved data layout formats
+
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x4
+// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
+//
+// - in                  : an array of block_q4_0 pointers
+// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
+//                         blck_size_interleave bytes
+// - xor_mask            : the mask to convert the nibbles in block_q4_0 quants bytes
+//                         from bias offset form to pure sign form (this saves subtract
+//                         operations during unpacking)
+//
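+// Illustration (blck_size_interleave = 4): the resulting block_q4_0x4 lays out
+//   d[0..3]  |  b0.qs[0..3] b1.qs[0..3] b2.qs[0..3] b3.qs[0..3]  |  b0.qs[4..7] ...
+// i.e. all four deltas first, then 4-byte chunks of quant bytes taken round-robin
+// from the four source blocks, each byte XOR-ed with xor_mask.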
+#if defined(__AVX__)
+#if defined(__F16C__)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y)     _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x))))
+#define GGML_F32Cx16_REPEAT_LOAD(x)  _mm512_cvtph_ps(_mm256_set_m128i(x, x))
+#endif
+// the  _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68))
+#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask))
+#else
+#if defined(__AVX512F__)
+static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) {
+    float tmp[16];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i + 8] = GGML_FP16_TO_FP32(y[i]);
+    }
+
+    return _mm512_loadu_ps(tmp);
+}
+static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) {
+    float tmp[16];
+    uint16_t tmphalf[8];
+    _mm_storeu_si128((__m128i*)tmphalf, x);
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
+        tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]);
+        tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]);
+        tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]);
+    }
+
+    return _mm512_loadu_ps(tmp);
+}
+#endif
+static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) {
+    float tmp[8];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+        tmp[i + 4] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) {
+    uint16_t tmphalf[8];
+    float tmp[8];
+
+    _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask));
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask)     __avx_repeat_f32cx8_load(x)
+#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask)     __avx_rearranged_f32cx8_load(x, arrangeMask)
+#if defined(__AVX512F__)
+#define GGML_F32Cx8x2_LOAD(x, y)     __avx512_f32cx8x2_load(x, y)
+#define GGML_F32Cx16_REPEAT_LOAD(x)  __avx512_repeat_f32cx16_load(x)
+#endif
+#endif
+#endif
+
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+#if defined(__AVX512F__)
+// add int16_t pairwise and return as 512 bit int vector
+static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
+    const __m512i ones = _mm512_set1_epi16(1);
+    return _mm512_madd_epi16(ones, x);
+}
+
+static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
+#if defined(__AVX512VNNI__)
+    const __m512i zero = _mm512_setzero_si512();
+    return _mm512_dpbusd_epi32(zero, ax, sy);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m512i dot = _mm512_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_int_32x16(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as 512 bit int vector
+static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
+    const __m512i zero = _mm512_setzero_si512();
+    // Get absolute values of x vectors
+    const __m512i ax = _mm512_abs_epi8(x);
+    // Sign the values of the y vectors
+    __mmask64 blt0 = _mm512_movepi8_mask(x);
+    const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y);
+    return mul_sum_us8_pairs_int32x16(ax, sy);
+}
+#endif
+
+// add int16_t pairwise and return as 256 bit int vector
+static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    return _mm256_madd_epi16(ones, x);
+}
+
+static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbusd_epi32(zero, ax, sy);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbusd_avx_epi32(zero, ax, sy);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_int32x8(dot);
+#endif
+}
+
+// Integer variant of the function defined in ggml-quants.c
+// multiply int8_t, add results pairwise twice and return as 256 bit int vector
+static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    return _mm256_dpbssd_epi32(zero, x, y);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_int32x8(ax, sy);
+#endif
+}
+#endif
+
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 8; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 4;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
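+            // e.g. with blck_size_interleave = 4: j = 21 -> src_id = (21 % 16) / 4 = 1,
+            // src_offset = (21 / 16) * 4 + (21 % 4) = 5, so output byte 21 is taken
+            // from element 5 of source row 1.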
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
+static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t srcv[4][8];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j);
+            for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+
+            for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+            for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+            for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+
+            const float amax = vmaxvq_f32(amaxv[0]);
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < 4; j++) {
+            float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]);
+            int32x4_t vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[1][2 * j], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[2][2 * j], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3);
+
+            v = vmulq_n_f32(srcv[3][2 * j], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3);
+            v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]);
+            vi = vcvtnq_s32_f32(v);
+            y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0);
+            y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1);
+            y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2);
+            y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    float id[4];
+    __m256 srcv[4][4];
+    __m256 idvec[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            // Load elements into 4 AVX vectors
+            __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 );
+            __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 );
+            __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 );
+            __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 );
+
+            // Compute max(abs(e)) for the block
+            const __m256 signBit = _mm256_set1_ps( -0.0f );
+            __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+            __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+            max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+            max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+            const float maxScalar = _mm_cvtss_f32( max4 );
+
+            // Divided by 127.f to mirror results in quantize_row_q8_0
+            const float d = maxScalar  / 127.f;
+            id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f;
+
+            // Store the scale for the individual block
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+
+            // Keep the values in blocks of eight; they are interleaved into the output below
+            srcv[row_iter][0] = v0;
+            srcv[row_iter][1] = v1;
+            srcv[row_iter][2] = v2;
+            srcv[row_iter][3] = v3;
+            idvec[row_iter] = _mm256_set1_ps(id[row_iter]);
+        }
+
+        // The loop iterates four times: each pass takes the corresponding eight-byte chunk from each of the four rows so the output is stored block-interleaved
+        for (int j = 0; j < 4; j++) {
+            // Apply the multiplier
+            __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]);
+            __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]);
+            __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]);
+            __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]);
+
+            // Round to nearest integer
+            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+            // Convert floats to integers
+            __m256i i0 = _mm256_cvtps_epi32( v0 );
+            __m256i i1 = _mm256_cvtps_epi32( v1 );
+            __m256i i2 = _mm256_cvtps_epi32( v2 );
+            __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+            // Convert int32 to int16
+            i0 = _mm256_packs_epi32( i0, i1 );
+            i2 = _mm256_packs_epi32( i2, i3 );
+            // Convert int16 to int8
+            i0 = _mm256_packs_epi16( i0, i2 );
+
+            //  Permute and store the quantized weights in the required order after the pack instruction
+            const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+            i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+            _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
+#else
+            // Since we don't have in AVX some necessary functions,
+            // we split the registers in half and call AVX2 analogs from SSE
+            __m128i ni0 = _mm256_castsi256_si128( i0 );
+            __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+            __m128i ni2 = _mm256_castsi256_si128( i1 );
+            __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+            __m128i ni4 = _mm256_castsi256_si128( i2 );
+            __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+            __m128i ni6 = _mm256_castsi256_si128( i3 );
+            __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+            // Convert int32 to int16
+            ni0 = _mm_packs_epi32( ni0, ni1 );
+            ni2 = _mm_packs_epi32( ni2, ni3 );
+            ni4 = _mm_packs_epi32( ni4, ni5 );
+            ni6 = _mm_packs_epi32( ni6, ni7 );
+            // Convert int16 to int8
+            ni0 = _mm_packs_epi16( ni0, ni2 );
+            ni4 = _mm_packs_epi16( ni4, ni6 );
+            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0);
+            _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4);
+#endif
+        }
+    }
+#else
+    // scalar
+    const int blck_size_interleave = 8;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+#endif
+}
+
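+// Quantize a tile of four rows to Q8_0, dispatching on the interleave width
+// (4- or 8-byte groups) used by the repacked weight layout.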
+static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    if (blck_size_interleave == 4) {
+        quantize_q8_0_4x4(x, vy, n_per_row);
+    } else if (blck_size_interleave == 8) {
+        quantize_q8_0_4x8(x, vy, n_per_row);
+    } else {
+        assert(false);
+    }
+}
+
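+// GEMV: one Q8_0-quantized row (vy) times Q4_0 weights repacked as block_q4_0x4
+// (4 columns, 4-byte interleave), writing nc float results to s. Uses the NEON
+// dot-product path when available, otherwise the scalar loop below.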
+static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = vld1q_s8(a_ptr->qs);
+                int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret = vdupq_n_s32(0);
+
+                ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0);
+                ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1);
+                ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2);
+                ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3);
+
+                ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0);
+                ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1);
+                ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2);
+                ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                                vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
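+// Same GEMV as above, but for block_q4_0x4 weights repacked with an 8-byte interleave.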
+static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx;
+
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+            float32x4_t acc = vdupq_n_f32(0);
+            for (int b = 0; b < nb; b++) {
+                int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs);
+                int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16);
+                int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32);
+                int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48);
+                float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+                int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs);
+                int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1);
+                int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2);
+                int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3);
+                float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+                int32x4_t ret0 = vdupq_n_s32(0);
+                int32x4_t ret1 = vdupq_n_s32(0);
+
+                ret0 = vdotq_s32(ret0, b0 << 4, a0);
+                ret1 = vdotq_s32(ret1, b1 << 4, a0);
+                ret0 = vdotq_s32(ret0, b2 << 4, a1);
+                ret1 = vdotq_s32(ret1, b3 << 4, a1);
+
+                ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2);
+                ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2);
+                ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3);
+                ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3);
+
+                int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+                acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4),
+                        vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+                a_ptr++;
+                b_ptr++;
+            }
+            vst1q_f32(s, acc);
+            s += ncols_interleaved;
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                        const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
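+// GEMV for block_q4_0x8 weights (8 columns, 8-byte interleave); SVE, AVX2 and
+// RISC-V vector paths with a portable scalar fallback.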
+static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE)
+    if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+
+        __asm__ __volatile__(
+            "ptrue p0.b\n"
+            "add %x[b_ptr], %x[b_ptr], #0x10\n"
+            "1:"  // Column loop
+            "add x22, %x[a_ptr], #0x2\n"
+            "mov z31.b, #0x0\n"
+            "mov x21, %x[nb]\n"
+            "2:"  // Block loop
+            "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n"
+            "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n"
+            "mov z28.s, #0x0\n"
+            "mov z27.s, #0x0\n"
+            "ld1rd { z26.d }, p0/Z, [x22]\n"
+            "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n"
+            "sub x20, x22, #0x2\n"
+            "sub x21, x21, #0x1\n"
+            "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n"
+            "ld1rd { z23.d }, p0/Z, [x22, #8]\n"
+            "lsl z22.b, z30.b, #0x4\n"
+            "lsl z16.b, z29.b, #0x4\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1rd { z21.d }, p0/Z, [x22, #16]\n"
+            "ld1rd { z20.d }, p0/Z, [x22, #24]\n"
+            "lsl z19.b, z25.b, #0x4\n"
+            "and z25.b, z25.b, #0xf0\n"
+            "ld1rh { z17.h }, p0/Z, [x20]\n"
+            "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n"
+            "sdot z28.s, z22.b, z26.b\n"
+            "sdot z27.s, z16.b, z26.b\n"
+            "lsl z16.b, z24.b, #0x4\n"
+            "add x22, x22, #0x22\n"
+            "and z24.b, z24.b, #0xf0\n"
+            "add %x[b_ptr], %x[b_ptr], #0x90\n"
+            "fcvt z17.s, p0/m, z17.h\n"
+            "fcvt z18.s, p0/m, z18.h\n"
+            "sdot z28.s, z19.b, z23.b\n"
+            "sdot z27.s, z16.b, z23.b\n"
+            "fmul z18.s, z18.s, z17.s\n"
+            "sdot z28.s, z30.b, z21.b\n"
+            "sdot z27.s, z29.b, z21.b\n"
+            "sdot z28.s, z25.b, z20.b\n"
+            "sdot z27.s, z24.b, z20.b\n"
+            "uzp1 z17.s, z28.s, z27.s\n"
+            "uzp2 z16.s, z28.s, z27.s\n"
+            "add z17.s, z17.s, z16.s\n"
+            "asr z17.s, z17.s, #0x4\n"
+            "scvtf z17.s, p0/m, z17.s\n"
+            "fmla z31.s, p0/M, z17.s, z18.s\n"
+            "cbnz x21, 2b\n"
+            "sub %x[nc], %x[nc], #0x8\n"
+            "st1w { z31.s }, p0, [%x[res_ptr]]\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "cbnz %x[nc], 1b\n"
+            : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc)
+            : [a_ptr] "r" (a_ptr), [nb] "r" (nb)
+            : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+#endif // #if defined(__ARM_FEATURE_SVE)
+#elif defined(__AVX2__)
+    // Lookup table to convert signed nibbles to signed bytes
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+    // Permute masks used for easier vector processing at later stages
+    __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // Mask to extract the low nibble of each byte
+    const __m256i m4b = _mm256_set1_epi8(0x0F);
+
+    int64_t b_nb = n / QK4_0;
+
+    const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+    const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy;
+
+    // Process Q8_0 blocks one by one
+    for (int64_t y = 0; y < nr; y++) {
+
+        // Pointers to LHS blocks of block_q8_0 format
+        const block_q8_0 * a_ptr = a_ptr_start + (y * nb);
+
+        // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation
+        for (int64_t x = 0; x < nc / 8; x++) {
+
+            // Pointers to RHS blocks
+            const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+            // Master FP accumulator
+            __m256 acc_row = _mm256_setzero_ps();
+
+            for (int64_t b = 0; b < nb; b++) {
+                // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7)
+                const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1);
+                const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2);
+                const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3);
+
+                // 4-bit -> 8-bit - Sign is maintained
+                const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7)
+                const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7)
+                const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15)
+                const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B4(8-15) B5(8-15) B6(8-15) B7(8-15)
+
+                const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23)
+                const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23)
+                const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31)
+                const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31)
+
+                // Load the scale values for the 8 blocks interleaved in block_q4_0x8
+                const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+
+                // Load and convert to FP32 scale from block_q8_0
+                const __m256 row_scale_f32 = _mm256_set1_ps(GGML_FP16_TO_FP32(a_ptr[b].d));
+
+                // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector
+                __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs));
+                __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
+
+                lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0(0-15) A0(0-15)
+                lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0(16-31) A0(16-31)
+
+                __m256i iacc = _mm256_setzero_si256();
+
+                // Dot product done within 32 bit lanes and accumulated in the same vector
+                // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
+                // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
+                // ...........................................................................
+                // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)));
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)));
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)));
+
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)));
+                iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)));
+
+                // Accumulated values multiplied with appropriate scales
+                acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row);
+            }
+
+            // Accumulated output values permuted so as to be stored in appropriate order post accumulation
+            acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask);
+            _mm256_storeu_ps(s + (y * nr + x * 8), acc_row);
+        }
+    }
+    return;
+#elif defined(__riscv_v_intrinsic)
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+            for (int l = 0; l < nb; l++) {
+                const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0];
+                const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8];
+                const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16];
+                const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24];
+                __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4));
+                const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4));
+                const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4));
+                const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4));
+
+                const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m));
+                const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                // vector version needs Zvfhmin extension
+                const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
+                const float b_scales[8] = {
+                    GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                    GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                };
+                const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+                const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
+                sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4);
+            }
+            __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+    {
+        float sumf[8];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                            const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
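+// GEMV for IQ4_NL weights repacked as block_iq4_nlx4: the packed nibbles index
+// the kvalues_iq4nl lookup table instead of being treated as signed Q4_0 values.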
+static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    {
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
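+// GEMM counterpart of ggml_gemv_q4_0_4x4_q8_0: processes Q8_0 activations four
+// rows at a time (nr must be a multiple of 4) against block_q4_0x4 weights. The
+// aarch64 path is hand-written assembly with a 16-row main loop and a 4-row tail,
+// followed by a portable scalar fallback.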
+static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x10, %x[nr]\n"
+            "mov x9, #0x88\n"
+            "cmp x10, #0x10\n"
+            "mul x9, %x[nb], x9\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x28, %x[b_ptr], #0x8\n"
+            "mov x27, %x[nc]\n"
+            "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x25, %x[a_ptr], #0x8\n"
+            "movi v15.16b, #0x0\n"
+            "movi v19.16b, #0x0\n"
+            "mov x24, %x[nb]\n"
+            "add x23, x25, x9\n"
+            "movi v18.16b, #0x0\n"
+            "movi v14.16b, #0x0\n"
+            "add x22, x23, x9\n"
+            "movi v11.16b, #0x0\n"
+            "movi v13.16b, #0x0\n"
+            "add x21, x22, x9\n"
+            "movi v23.16b, #0x0\n"
+            "movi v16.16b, #0x0\n"
+            "movi v25.16b, #0x0\n"
+            "movi v7.16b, #0x0\n"
+            "movi v0.16b, #0x0\n"
+            "movi v4.16b, #0x0\n"
+            "movi v5.16b, #0x0\n"
+            "movi v21.16b, #0x0\n"
+            "movi v8.16b, #0x0\n"
+            "movi v1.16b, #0x0\n"
+            "3:"  // Block loop
+            "ldr q3, [x28, #0x0]\n"
+            "ldr q31, [x25, #0x0]\n"
+            "movi v28.16b, #0x4\n"
+            "movi v10.4s, #0x0\n"
+            "ldr q22, [x28, #0x10]\n"
+            "ldr q6, [x25, #0x10]\n"
+            "movi v29.4s, #0x0\n"
+            "movi v9.4s, #0x0\n"
+            "ldr q27, [x28, #0x20]\n"
+            "ldr q30, [x28, #0x30]\n"
+            "movi v20.4s, #0x0\n"
+            "movi v24.16b, #0xf0\n"
+            "ldr d2, [x25, #-0x8]\n"
+            "ldr d26, [x23, #-0x8]\n"
+            "sshl v12.16b, v3.16b, v28.16b\n"
+            "sub x20, x28, #0x8\n"
+            "ldr d17, [x20, #0x0]\n"
+            "and v3.16b, v3.16b, v24.16b\n"
+            "subs x24, x24, #0x1\n"
+            "add x28, x28, #0x48\n"
+            ".inst 0x4f9fe18a  // sdot v10.4s, v12.16b, v31.4b[0]\n"
+            ".inst 0x4fbfe19d  // sdot v29.4s, v12.16b, v31.4b[1]\n"
+            ".inst 0x4f9fe989  // sdot v9.4s, v12.16b, v31.4b[2]\n"
+            ".inst 0x4fbfe994  // sdot v20.4s, v12.16b, v31.4b[3]\n"
+            "sshl v31.16b, v22.16b, v28.16b\n"
+            "and v22.16b, v22.16b, v24.16b\n"
+            "fcvtl v17.4s, v17.4h\n"
+            "fcvtl v2.4s, v2.4h\n"
+            "fcvtl v26.4s, v26.4h\n"
+            ".inst 0x4f86e3ea  // sdot v10.4s, v31.16b, v6.4b[0]\n"
+            ".inst 0x4fa6e3fd  // sdot v29.4s, v31.16b, v6.4b[1]\n"
+            ".inst 0x4f86ebe9  // sdot v9.4s, v31.16b, v6.4b[2]\n"
+            ".inst 0x4fa6ebf4  // sdot v20.4s, v31.16b, v6.4b[3]\n"
+            "sshl v6.16b, v27.16b, v28.16b\n"
+            "sshl v28.16b, v30.16b, v28.16b\n"
+            "and v27.16b, v27.16b, v24.16b\n"
+            "and v30.16b, v30.16b, v24.16b\n"
+            "ldr q24, [x25, #0x20]\n"
+            ".inst 0x4f98e0ca  // sdot v10.4s, v6.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+            ".inst 0x4f98e8c9  // sdot v9.4s, v6.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e8d4  // sdot v20.4s, v6.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x30]\n"
+            ".inst 0x4f98e38a  // sdot v10.4s, v28.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e39d  // sdot v29.4s, v28.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb89  // sdot v9.4s, v28.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb94  // sdot v20.4s, v28.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x40]\n"
+            ".inst 0x4f98e06a  // sdot v10.4s, v3.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+            ".inst 0x4f98e869  // sdot v9.4s, v3.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e874  // sdot v20.4s, v3.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x50]\n"
+            ".inst 0x4f98e2ca  // sdot v10.4s, v22.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e2dd  // sdot v29.4s, v22.16b, v24.4b[1]\n"
+            ".inst 0x4f98eac9  // sdot v9.4s, v22.16b, v24.4b[2]\n"
+            ".inst 0x4fb8ead4  // sdot v20.4s, v22.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x60]\n"
+            ".inst 0x4f98e36a  // sdot v10.4s, v27.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb69  // sdot v9.4s, v27.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb74  // sdot v20.4s, v27.16b, v24.4b[3]\n"
+            "ldr q24, [x25, #0x70]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x4f98e3ca  // sdot v10.4s, v30.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e3dd  // sdot v29.4s, v30.16b, v24.4b[1]\n"
+            ".inst 0x4f98ebc9  // sdot v9.4s, v30.16b, v24.4b[2]\n"
+            ".inst 0x4fb8ebd4  // sdot v20.4s, v30.16b, v24.4b[3]\n"
+            "fmul v24.4s, v17.4s, v2.s[0]\n"
+            "scvtf v10.4s, v10.4s, #0x4\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "scvtf v9.4s, v9.4s, #0x4\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v15.4s, v10.4s, v24.4s\n"
+            "ldr q24, [x23, #0x0]\n"
+            "fmul v10.4s, v17.4s, v2.s[1]\n"
+            "fmla v19.4s, v29.4s, v10.4s\n"
+            "ldr q10, [x23, #0x10]\n"
+            "fmul v29.4s, v17.4s, v2.s[2]\n"
+            "fmul v2.4s, v17.4s, v2.s[3]\n"
+            "fmla v18.4s, v9.4s, v29.4s\n"
+            "movi v9.4s, #0x0\n"
+            "movi v29.4s, #0x0\n"
+            ".inst 0x4f98e189  // sdot v9.4s, v12.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e19d  // sdot v29.4s, v12.16b, v24.4b[1]\n"
+            "fmla v14.4s, v20.4s, v2.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v2.4s, #0x0\n"
+            ".inst 0x4f98e994  // sdot v20.4s, v12.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+            "ldr q24, [x23, #0x20]\n"
+            ".inst 0x4f8ae3e9  // sdot v9.4s, v31.16b, v10.4b[0]\n"
+            ".inst 0x4faae3fd  // sdot v29.4s, v31.16b, v10.4b[1]\n"
+            ".inst 0x4f8aebf4  // sdot v20.4s, v31.16b, v10.4b[2]\n"
+            ".inst 0x4faaebe2  // sdot v2.4s, v31.16b, v10.4b[3]\n"
+            "ldr q10, [x23, #0x30]\n"
+            ".inst 0x4f98e0c9  // sdot v9.4s, v6.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e0dd  // sdot v29.4s, v6.16b, v24.4b[1]\n"
+            ".inst 0x4f98e8d4  // sdot v20.4s, v6.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+            "ldr q24, [x23, #0x40]\n"
+            ".inst 0x4f8ae389  // sdot v9.4s, v28.16b, v10.4b[0]\n"
+            ".inst 0x4faae39d  // sdot v29.4s, v28.16b, v10.4b[1]\n"
+            ".inst 0x4f8aeb94  // sdot v20.4s, v28.16b, v10.4b[2]\n"
+            ".inst 0x4faaeb82  // sdot v2.4s, v28.16b, v10.4b[3]\n"
+            "ldr q10, [x23, #0x50]\n"
+            ".inst 0x4f98e069  // sdot v9.4s, v3.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e07d  // sdot v29.4s, v3.16b, v24.4b[1]\n"
+            ".inst 0x4f98e874  // sdot v20.4s, v3.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+            "ldr q24, [x23, #0x60]\n"
+            ".inst 0x4f8ae2c9  // sdot v9.4s, v22.16b, v10.4b[0]\n"
+            ".inst 0x4faae2dd  // sdot v29.4s, v22.16b, v10.4b[1]\n"
+            ".inst 0x4f8aead4  // sdot v20.4s, v22.16b, v10.4b[2]\n"
+            ".inst 0x4faaeac2  // sdot v2.4s, v22.16b, v10.4b[3]\n"
+            "ldr q10, [x23, #0x70]\n"
+            "add x23, x23, #0x88\n"
+            ".inst 0x4f98e369  // sdot v9.4s, v27.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e37d  // sdot v29.4s, v27.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb74  // sdot v20.4s, v27.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x0]\n"
+            ".inst 0x4f8ae3c9  // sdot v9.4s, v30.16b, v10.4b[0]\n"
+            ".inst 0x4faae3dd  // sdot v29.4s, v30.16b, v10.4b[1]\n"
+            ".inst 0x4f8aebd4  // sdot v20.4s, v30.16b, v10.4b[2]\n"
+            ".inst 0x4faaebc2  // sdot v2.4s, v30.16b, v10.4b[3]\n"
+            "fmul v10.4s, v17.4s, v26.s[0]\n"
+            "scvtf v9.4s, v9.4s, #0x4\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "scvtf v2.4s, v2.4s, #0x4\n"
+            "fmla v11.4s, v9.4s, v10.4s\n"
+            "ldr q9, [x22, #0x10]\n"
+            "fmul v10.4s, v17.4s, v26.s[1]\n"
+            "fmla v13.4s, v29.4s, v10.4s\n"
+            "ldr d29, [x22, #-0x8]\n"
+            "fmul v10.4s, v17.4s, v26.s[2]\n"
+            "fmul v26.4s, v17.4s, v26.s[3]\n"
+            "fcvtl v29.4s, v29.4h\n"
+            "fmla v23.4s, v20.4s, v10.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v10.4s, #0x0\n"
+            "fmla v16.4s, v2.4s, v26.4s\n"
+            "movi v26.4s, #0x0\n"
+            "movi v2.4s, #0x0\n"
+            ".inst 0x4f98e194  // sdot v20.4s, v12.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+            ".inst 0x4f98e99a  // sdot v26.4s, v12.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e982  // sdot v2.4s, v12.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x20]\n"
+            ".inst 0x4f89e3f4  // sdot v20.4s, v31.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+            ".inst 0x4f89ebfa  // sdot v26.4s, v31.16b, v9.4b[2]\n"
+            ".inst 0x4fa9ebe2  // sdot v2.4s, v31.16b, v9.4b[3]\n"
+            "ldr q9, [x22, #0x30]\n"
+            ".inst 0x4f98e0d4  // sdot v20.4s, v6.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e0ca  // sdot v10.4s, v6.16b, v24.4b[1]\n"
+            ".inst 0x4f98e8da  // sdot v26.4s, v6.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e8c2  // sdot v2.4s, v6.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x40]\n"
+            ".inst 0x4f89e394  // sdot v20.4s, v28.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+            ".inst 0x4f89eb9a  // sdot v26.4s, v28.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eb82  // sdot v2.4s, v28.16b, v9.4b[3]\n"
+            "ldr q9, [x22, #0x50]\n"
+            ".inst 0x4f98e074  // sdot v20.4s, v3.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e06a  // sdot v10.4s, v3.16b, v24.4b[1]\n"
+            ".inst 0x4f98e87a  // sdot v26.4s, v3.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e862  // sdot v2.4s, v3.16b, v24.4b[3]\n"
+            "ldr q24, [x22, #0x60]\n"
+            ".inst 0x4f89e2d4  // sdot v20.4s, v22.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+            ".inst 0x4f89eada  // sdot v26.4s, v22.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eac2  // sdot v2.4s, v22.16b, v9.4b[3]\n"
+            "ldr q9, [x22, #0x70]\n"
+            "add x22, x22, #0x88\n"
+            ".inst 0x4f98e374  // sdot v20.4s, v27.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e36a  // sdot v10.4s, v27.16b, v24.4b[1]\n"
+            ".inst 0x4f98eb7a  // sdot v26.4s, v27.16b, v24.4b[2]\n"
+            ".inst 0x4fb8eb62  // sdot v2.4s, v27.16b, v24.4b[3]\n"
+            "ldr q24, [x21, #0x0]\n"
+            ".inst 0x4f89e3d4  // sdot v20.4s, v30.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e3ca  // sdot v10.4s, v30.16b, v9.4b[1]\n"
+            ".inst 0x4f89ebda  // sdot v26.4s, v30.16b, v9.4b[2]\n"
+            ".inst 0x4fa9ebc2  // sdot v2.4s, v30.16b, v9.4b[3]\n"
+            "fmul v9.4s, v17.4s, v29.s[0]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "scvtf v10.4s, v10.4s, #0x4\n"
+            "scvtf v26.4s, v26.4s, #0x4\n"
+            "scvtf v2.4s, v2.4s, #0x4\n"
+            "fmla v25.4s, v20.4s, v9.4s\n"
+            "ldr q9, [x21, #0x10]\n"
+            "fmul v20.4s, v17.4s, v29.s[1]\n"
+            "fmla v7.4s, v10.4s, v20.4s\n"
+            "ldr d20, [x21, #-0x8]\n"
+            "fmul v10.4s, v17.4s, v29.s[2]\n"
+            "fmul v29.4s, v17.4s, v29.s[3]\n"
+            "fcvtl v20.4s, v20.4h\n"
+            "fmla v0.4s, v26.4s, v10.4s\n"
+            "movi v26.4s, #0x0\n"
+            "movi v10.4s, #0x0\n"
+            "fmla v4.4s, v2.4s, v29.4s\n"
+            "movi v2.4s, #0x0\n"
+            "movi v29.4s, #0x0\n"
+            ".inst 0x4f98e19a  // sdot v26.4s, v12.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e18a  // sdot v10.4s, v12.16b, v24.4b[1]\n"
+            ".inst 0x4f98e982  // sdot v2.4s, v12.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e99d  // sdot v29.4s, v12.16b, v24.4b[3]\n"
+            "ldr q12, [x21, #0x20]\n"
+            "fmul v24.4s, v17.4s, v20.s[0]\n"
+            ".inst 0x4f89e3fa  // sdot v26.4s, v31.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e3ea  // sdot v10.4s, v31.16b, v9.4b[1]\n"
+            ".inst 0x4f89ebe2  // sdot v2.4s, v31.16b, v9.4b[2]\n"
+            ".inst 0x4fa9ebfd  // sdot v29.4s, v31.16b, v9.4b[3]\n"
+            "ldr q9, [x21, #0x30]\n"
+            "fmul v31.4s, v17.4s, v20.s[1]\n"
+            ".inst 0x4f8ce0da  // sdot v26.4s, v6.16b, v12.4b[0]\n"
+            ".inst 0x4face0ca  // sdot v10.4s, v6.16b, v12.4b[1]\n"
+            ".inst 0x4f8ce8c2  // sdot v2.4s, v6.16b, v12.4b[2]\n"
+            ".inst 0x4face8dd  // sdot v29.4s, v6.16b, v12.4b[3]\n"
+            "ldr q12, [x21, #0x40]\n"
+            "fmul v6.4s, v17.4s, v20.s[2]\n"
+            "fmul v20.4s, v17.4s, v20.s[3]\n"
+            ".inst 0x4f89e39a  // sdot v26.4s, v28.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e38a  // sdot v10.4s, v28.16b, v9.4b[1]\n"
+            ".inst 0x4f89eb82  // sdot v2.4s, v28.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eb9d  // sdot v29.4s, v28.16b, v9.4b[3]\n"
+            "ldr q9, [x21, #0x50]\n"
+            ".inst 0x4f8ce07a  // sdot v26.4s, v3.16b, v12.4b[0]\n"
+            ".inst 0x4face06a  // sdot v10.4s, v3.16b, v12.4b[1]\n"
+            ".inst 0x4f8ce862  // sdot v2.4s, v3.16b, v12.4b[2]\n"
+            ".inst 0x4face87d  // sdot v29.4s, v3.16b, v12.4b[3]\n"
+            "ldr q12, [x21, #0x60]\n"
+            ".inst 0x4f89e2da  // sdot v26.4s, v22.16b, v9.4b[0]\n"
+            ".inst 0x4fa9e2ca  // sdot v10.4s, v22.16b, v9.4b[1]\n"
+            ".inst 0x4f89eac2  // sdot v2.4s, v22.16b, v9.4b[2]\n"
+            ".inst 0x4fa9eadd  // sdot v29.4s, v22.16b, v9.4b[3]\n"
+            "ldr q17, [x21, #0x70]\n"
+            "add x21, x21, #0x88\n"
+            ".inst 0x4f8ce37a  // sdot v26.4s, v27.16b, v12.4b[0]\n"
+            ".inst 0x4face36a  // sdot v10.4s, v27.16b, v12.4b[1]\n"
+            ".inst 0x4f8ceb62  // sdot v2.4s, v27.16b, v12.4b[2]\n"
+            ".inst 0x4faceb7d  // sdot v29.4s, v27.16b, v12.4b[3]\n"
+            ".inst 0x4f91e3da  // sdot v26.4s, v30.16b, v17.4b[0]\n"
+            ".inst 0x4fb1e3ca  // sdot v10.4s, v30.16b, v17.4b[1]\n"
+            ".inst 0x4f91ebc2  // sdot v2.4s, v30.16b, v17.4b[2]\n"
+            ".inst 0x4fb1ebdd  // sdot v29.4s, v30.16b, v17.4b[3]\n"
+            "scvtf v26.4s, v26.4s, #0x4\n"
+            "scvtf v10.4s, v10.4s, #0x4\n"
+            "fmla v5.4s, v26.4s, v24.4s\n"
+            "scvtf v2.4s, v2.4s, #0x4\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "fmla v21.4s, v10.4s, v31.4s\n"
+            "fmla v8.4s, v2.4s, v6.4s\n"
+            "fmla v1.4s, v29.4s, v20.4s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x27, x27, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "str q15, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q19, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q18, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q14, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q11, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q13, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q23, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q16, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q25, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q7, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q0, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q4, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q5, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q21, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q8, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q1, [x20, #0x0]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x10, x10, #0x10\n"
+            "cmp x10, #0x10\n"
+            "mov %x[res_ptr], x26\n"
+            "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x10, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x24, %x[b_ptr], #0x8\n"
+            "mov x23, %x[nc]\n"
+            "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "movi v15.16b, #0x0\n"
+            "movi v19.16b, #0x0\n"
+            "add x25, %x[a_ptr], #0x8\n"
+            "mov x21, %x[nb]\n"
+            "movi v18.16b, #0x0\n"
+            "movi v14.16b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ldr q7, [x24, #0x0]\n"
+            "ldr q5, [x25, #0x0]\n"
+            "movi v9.16b, #0x4\n"
+            "movi v4.4s, #0x0\n"
+            "ldr q3, [x24, #0x10]\n"
+            "ldr q2, [x25, #0x10]\n"
+            "movi v1.4s, #0x0\n"
+            "movi v0.4s, #0x0\n"
+            "ldr q13, [x24, #0x20]\n"
+            "ldr q31, [x25, #0x20]\n"
+            "movi v30.4s, #0x0\n"
+            "movi v29.16b, #0xf0\n"
+            "ldr q28, [x24, #0x30]\n"
+            "ldr q27, [x25, #0x30]\n"
+            "sshl v20.16b, v7.16b, v9.16b\n"
+            "sub x20, x24, #0x8\n"
+            "ldr q26, [x25, #0x40]\n"
+            "ldr q25, [x25, #0x50]\n"
+            "sshl v17.16b, v3.16b, v9.16b\n"
+            "and v7.16b, v7.16b, v29.16b\n"
+            "ldr q24, [x25, #0x60]\n"
+            "ldr q16, [x25, #0x70]\n"
+            "sshl v22.16b, v13.16b, v9.16b\n"
+            "and v3.16b, v3.16b, v29.16b\n"
+            "ldr d21, [x20, #0x0]\n"
+            "ldr d12, [x25, #-0x8]\n"
+            ".inst 0x4f85e284  // sdot v4.4s, v20.16b, v5.4b[0]\n"
+            ".inst 0x4fa5e281  // sdot v1.4s, v20.16b, v5.4b[1]\n"
+            ".inst 0x4f85ea80  // sdot v0.4s, v20.16b, v5.4b[2]\n"
+            ".inst 0x4fa5ea9e  // sdot v30.4s, v20.16b, v5.4b[3]\n"
+            "sshl v9.16b, v28.16b, v9.16b\n"
+            "subs x21, x21, #0x1\n"
+            "and v13.16b, v13.16b, v29.16b\n"
+            "and v28.16b, v28.16b, v29.16b\n"
+            "add x25, x25, #0x88\n"
+            "add x24, x24, #0x48\n"
+            "fcvtl v21.4s, v21.4h\n"
+            "fcvtl v12.4s, v12.4h\n"
+            ".inst 0x4f82e224  // sdot v4.4s, v17.16b, v2.4b[0]\n"
+            ".inst 0x4fa2e221  // sdot v1.4s, v17.16b, v2.4b[1]\n"
+            ".inst 0x4f82ea20  // sdot v0.4s, v17.16b, v2.4b[2]\n"
+            ".inst 0x4fa2ea3e  // sdot v30.4s, v17.16b, v2.4b[3]\n"
+            "fmul v11.4s, v21.4s, v12.s[0]\n"
+            "fmul v23.4s, v21.4s, v12.s[1]\n"
+            "fmul v17.4s, v21.4s, v12.s[2]\n"
+            ".inst 0x4f9fe2c4  // sdot v4.4s, v22.16b, v31.4b[0]\n"
+            "fmul v6.4s, v21.4s, v12.s[3]\n"
+            ".inst 0x4fbfe2c1  // sdot v1.4s, v22.16b, v31.4b[1]\n"
+            ".inst 0x4f9feac0  // sdot v0.4s, v22.16b, v31.4b[2]\n"
+            ".inst 0x4fbfeade  // sdot v30.4s, v22.16b, v31.4b[3]\n"
+            ".inst 0x4f9be124  // sdot v4.4s, v9.16b, v27.4b[0]\n"
+            ".inst 0x4fbbe121  // sdot v1.4s, v9.16b, v27.4b[1]\n"
+            ".inst 0x4f9be920  // sdot v0.4s, v9.16b, v27.4b[2]\n"
+            ".inst 0x4fbbe93e  // sdot v30.4s, v9.16b, v27.4b[3]\n"
+            ".inst 0x4f9ae0e4  // sdot v4.4s, v7.16b, v26.4b[0]\n"
+            ".inst 0x4fbae0e1  // sdot v1.4s, v7.16b, v26.4b[1]\n"
+            ".inst 0x4f9ae8e0  // sdot v0.4s, v7.16b, v26.4b[2]\n"
+            ".inst 0x4fbae8fe  // sdot v30.4s, v7.16b, v26.4b[3]\n"
+            ".inst 0x4f99e064  // sdot v4.4s, v3.16b, v25.4b[0]\n"
+            ".inst 0x4fb9e061  // sdot v1.4s, v3.16b, v25.4b[1]\n"
+            ".inst 0x4f99e860  // sdot v0.4s, v3.16b, v25.4b[2]\n"
+            ".inst 0x4fb9e87e  // sdot v30.4s, v3.16b, v25.4b[3]\n"
+            ".inst 0x4f98e1a4  // sdot v4.4s, v13.16b, v24.4b[0]\n"
+            ".inst 0x4fb8e1a1  // sdot v1.4s, v13.16b, v24.4b[1]\n"
+            ".inst 0x4f98e9a0  // sdot v0.4s, v13.16b, v24.4b[2]\n"
+            ".inst 0x4fb8e9be  // sdot v30.4s, v13.16b, v24.4b[3]\n"
+            ".inst 0x4f90e384  // sdot v4.4s, v28.16b, v16.4b[0]\n"
+            ".inst 0x4fb0e381  // sdot v1.4s, v28.16b, v16.4b[1]\n"
+            ".inst 0x4f90eb80  // sdot v0.4s, v28.16b, v16.4b[2]\n"
+            ".inst 0x4fb0eb9e  // sdot v30.4s, v28.16b, v16.4b[3]\n"
+            "scvtf v4.4s, v4.4s, #0x4\n"
+            "scvtf v1.4s, v1.4s, #0x4\n"
+            "scvtf v0.4s, v0.4s, #0x4\n"
+            "fmla v15.4s, v4.4s, v11.4s\n"
+            "scvtf v30.4s, v30.4s, #0x4\n"
+            "fmla v19.4s, v1.4s, v23.4s\n"
+            "fmla v18.4s, v0.4s, v17.4s\n"
+            "fmla v14.4s, v30.4s, v6.4s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x10, #0x1\n"
+            "str q15, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x2\n"
+            "str q19, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x3\n"
+            "str q18, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "str q14, [x20, #0x0]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x23, x23, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "bne 6b\n"
+            "subs x10, x10, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x9\n"
+            "mov %x[res_ptr], x22\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+        );
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
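+    // Portable reference fallback, taken when the NEON dot-product path above is not used:
+    // each 4-bit weight nibble is re-signed by shifting the low nibble up (v0) or masking the
+    // high nibble in place (v1), multiplied against the interleaved q8_0 activations, and the
+    // integer dot product is shifted right by 4 to drop the extra factor of 16 before being
+    // scaled by the two FP16 block deltas.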
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                    const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
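+// Same GEMM kernel shape as above but for the 4x8 interleaved layout (blocklen 8); the fast
+// path below requires NEON plus the int8 matrix-multiply (i8mm) smmla instructions.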
+static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x10, %x[nr]\n"
+            "mov x9, #0x88\n"
+            "cmp x10, #0x10\n"
+            "mul x9, %x[nb], x9\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x28, %x[b_ptr], #0x8\n"
+            "mov x27, %x[nc]\n"
+            "add x26, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x25, %x[a_ptr], #0x8\n"
+            "movi v2.16b, #0x0\n"
+            "movi v10.16b, #0x0\n"
+            "mov x24, %x[nb]\n"
+            "add x23, x25, x9\n"
+            "movi v12.16b, #0x0\n"
+            "movi v28.16b, #0x0\n"
+            "add x22, x23, x9\n"
+            "movi v11.16b, #0x0\n"
+            "movi v13.16b, #0x0\n"
+            "add x21, x22, x9\n"
+            "movi v22.16b, #0x0\n"
+            "movi v23.16b, #0x0\n"
+            "movi v25.16b, #0x0\n"
+            "movi v5.16b, #0x0\n"
+            "movi v7.16b, #0x0\n"
+            "movi v4.16b, #0x0\n"
+            "movi v6.16b, #0x0\n"
+            "movi v30.16b, #0x0\n"
+            "movi v24.16b, #0x0\n"
+            "movi v14.16b, #0x0\n"
+            "3:"  // Block loop
+            "ldr q21, [x28, #0x0]\n"
+            "ldr q16, [x28, #0x10]\n"
+            "movi v1.16b, #0x4\n"
+            "movi v19.4s, #0x0\n"
+            "ldr q27, [x25, #0x0]\n"
+            "ldr q15, [x25, #0x10]\n"
+            "movi v26.4s, #0x0\n"
+            "movi v18.4s, #0x0\n"
+            "ldr q29, [x28, #0x20]\n"
+            "ldr q3, [x28, #0x30]\n"
+            "movi v17.4s, #0x0\n"
+            "movi v0.16b, #0xf0\n"
+            "ldr d20, [x25, #-0x8]\n"
+            "ldr d9, [x23, #-0x8]\n"
+            "sshl v8.16b, v21.16b, v1.16b\n"
+            "sshl v31.16b, v16.16b, v1.16b\n"
+            "and v21.16b, v21.16b, v0.16b\n"
+            "and v16.16b, v16.16b, v0.16b\n"
+            "sub x20, x28, #0x8\n"
+            "subs x24, x24, #0x1\n"
+            "add x28, x28, #0x48\n"
+            ".inst 0x4e88a773  // smmla v19.4s, v27.16b, v8.16b\n"
+            ".inst 0x4e9fa77a  // smmla v26.4s, v27.16b, v31.16b\n"
+            "ldr q27, [x25, #0x20]\n"
+            ".inst 0x4e88a5f2  // smmla v18.4s, v15.16b, v8.16b\n"
+            ".inst 0x4e9fa5f1  // smmla v17.4s, v15.16b, v31.16b\n"
+            "sshl v15.16b, v29.16b, v1.16b\n"
+            "sshl v1.16b, v3.16b, v1.16b\n"
+            "and v29.16b, v29.16b, v0.16b\n"
+            "and v3.16b, v3.16b, v0.16b\n"
+            "ldr q0, [x25, #0x30]\n"
+            "fcvtl v20.4s, v20.4h\n"
+            ".inst 0x4e8fa773  // smmla v19.4s, v27.16b, v15.16b\n"
+            "fcvtl v9.4s, v9.4h\n"
+            ".inst 0x4e81a77a  // smmla v26.4s, v27.16b, v1.16b\n"
+            "ldr q27, [x25, #0x40]\n"
+            ".inst 0x4e8fa412  // smmla v18.4s, v0.16b, v15.16b\n"
+            ".inst 0x4e81a411  // smmla v17.4s, v0.16b, v1.16b\n"
+            "ldr q0, [x25, #0x50]\n"
+            ".inst 0x4e95a773  // smmla v19.4s, v27.16b, v21.16b\n"
+            ".inst 0x4e90a77a  // smmla v26.4s, v27.16b, v16.16b\n"
+            "ldr q27, [x25, #0x60]\n"
+            ".inst 0x4e95a412  // smmla v18.4s, v0.16b, v21.16b\n"
+            ".inst 0x4e90a411  // smmla v17.4s, v0.16b, v16.16b\n"
+            "ldr q0, [x25, #0x70]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x4e9da773  // smmla v19.4s, v27.16b, v29.16b\n"
+            ".inst 0x4e83a77a  // smmla v26.4s, v27.16b, v3.16b\n"
+            "ldr d27, [x20, #0x0]\n"
+            ".inst 0x4e9da412  // smmla v18.4s, v0.16b, v29.16b\n"
+            ".inst 0x4e83a411  // smmla v17.4s, v0.16b, v3.16b\n"
+            "fcvtl v27.4s, v27.4h\n"
+            "uzp1 v0.2d, v19.2d, v26.2d\n"
+            "uzp2 v26.2d, v19.2d, v26.2d\n"
+            "fmul v19.4s, v27.4s, v20.s[0]\n"
+            "scvtf v0.4s, v0.4s, #0x4\n"
+            "scvtf v26.4s, v26.4s, #0x4\n"
+            "fmla v2.4s, v0.4s, v19.4s\n"
+            "ldr q19, [x23, #0x0]\n"
+            "uzp1 v0.2d, v18.2d, v17.2d\n"
+            "uzp2 v18.2d, v18.2d, v17.2d\n"
+            "fmul v17.4s, v27.4s, v20.s[1]\n"
+            "scvtf v0.4s, v0.4s, #0x4\n"
+            "scvtf v18.4s, v18.4s, #0x4\n"
+            "fmla v10.4s, v26.4s, v17.4s\n"
+            "ldr q17, [x23, #0x10]\n"
+            "fmul v26.4s, v27.4s, v20.s[2]\n"
+            "fmul v20.4s, v27.4s, v20.s[3]\n"
+            "fmla v12.4s, v0.4s, v26.4s\n"
+            "ldr d0, [x22, #-0x8]\n"
+            "ldr d26, [x21, #-0x8]\n"
+            "fcvtl v0.4s, v0.4h\n"
+            "fmla v28.4s, v18.4s, v20.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v18.4s, #0x0\n"
+            ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+            ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+            "ldr q19, [x23, #0x20]\n"
+            "fcvtl v26.4s, v26.4h\n"
+            ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+            ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+            "ldr q19, [x23, #0x40]\n"
+            ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+            ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+            "ldr q19, [x23, #0x60]\n"
+            ".inst 0x4e9da674  // smmla v20.4s, v19.16b, v29.16b\n"
+            ".inst 0x4e83a672  // smmla v18.4s, v19.16b, v3.16b\n"
+            "uzp1 v19.2d, v20.2d, v18.2d\n"
+            "scvtf v19.4s, v19.4s, #0x4\n"
+            "uzp2 v20.2d, v20.2d, v18.2d\n"
+            "fmul v18.4s, v27.4s, v9.s[0]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v11.4s, v19.4s, v18.4s\n"
+            "ldr q18, [x22, #0x0]\n"
+            "fmul v19.4s, v27.4s, v9.s[1]\n"
+            "fmla v13.4s, v20.4s, v19.4s\n"
+            "movi v19.4s, #0x0\n"
+            "movi v20.4s, #0x0\n"
+            ".inst 0x4e88a633  // smmla v19.4s, v17.16b, v8.16b\n"
+            ".inst 0x4e9fa634  // smmla v20.4s, v17.16b, v31.16b\n"
+            "ldr q17, [x23, #0x30]\n"
+            ".inst 0x4e8fa633  // smmla v19.4s, v17.16b, v15.16b\n"
+            ".inst 0x4e81a634  // smmla v20.4s, v17.16b, v1.16b\n"
+            "ldr q17, [x23, #0x50]\n"
+            ".inst 0x4e95a633  // smmla v19.4s, v17.16b, v21.16b\n"
+            ".inst 0x4e90a634  // smmla v20.4s, v17.16b, v16.16b\n"
+            "ldr q17, [x23, #0x70]\n"
+            "add x23, x23, #0x88\n"
+            ".inst 0x4e9da633  // smmla v19.4s, v17.16b, v29.16b\n"
+            ".inst 0x4e83a634  // smmla v20.4s, v17.16b, v3.16b\n"
+            "uzp1 v17.2d, v19.2d, v20.2d\n"
+            "scvtf v17.4s, v17.4s, #0x4\n"
+            "uzp2 v20.2d, v19.2d, v20.2d\n"
+            "fmul v19.4s, v27.4s, v9.s[2]\n"
+            "fmul v9.4s, v27.4s, v9.s[3]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v22.4s, v17.4s, v19.4s\n"
+            "ldr q17, [x22, #0x10]\n"
+            "movi v19.4s, #0x0\n"
+            ".inst 0x4e88a653  // smmla v19.4s, v18.16b, v8.16b\n"
+            "fmla v23.4s, v20.4s, v9.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v9.4s, #0x0\n"
+            ".inst 0x4e9fa654  // smmla v20.4s, v18.16b, v31.16b\n"
+            "ldr q18, [x22, #0x20]\n"
+            ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+            ".inst 0x4e8fa653  // smmla v19.4s, v18.16b, v15.16b\n"
+            ".inst 0x4e81a654  // smmla v20.4s, v18.16b, v1.16b\n"
+            "ldr q18, [x22, #0x40]\n"
+            ".inst 0x4e95a653  // smmla v19.4s, v18.16b, v21.16b\n"
+            ".inst 0x4e90a654  // smmla v20.4s, v18.16b, v16.16b\n"
+            "ldr q18, [x22, #0x60]\n"
+            ".inst 0x4e9da653  // smmla v19.4s, v18.16b, v29.16b\n"
+            ".inst 0x4e83a654  // smmla v20.4s, v18.16b, v3.16b\n"
+            "movi v18.4s, #0x0\n"
+            ".inst 0x4e9fa632  // smmla v18.4s, v17.16b, v31.16b\n"
+            "ldr q17, [x22, #0x30]\n"
+            ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+            ".inst 0x4e81a632  // smmla v18.4s, v17.16b, v1.16b\n"
+            "ldr q17, [x22, #0x50]\n"
+            ".inst 0x4e95a629  // smmla v9.4s, v17.16b, v21.16b\n"
+            ".inst 0x4e90a632  // smmla v18.4s, v17.16b, v16.16b\n"
+            "ldr q17, [x22, #0x70]\n"
+            "add x22, x22, #0x88\n"
+            ".inst 0x4e9da629  // smmla v9.4s, v17.16b, v29.16b\n"
+            ".inst 0x4e83a632  // smmla v18.4s, v17.16b, v3.16b\n"
+            "uzp1 v17.2d, v19.2d, v20.2d\n"
+            "uzp2 v20.2d, v19.2d, v20.2d\n"
+            "fmul v19.4s, v27.4s, v0.s[0]\n"
+            "scvtf v17.4s, v17.4s, #0x4\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "fmla v25.4s, v17.4s, v19.4s\n"
+            "ldr q19, [x21, #0x0]\n"
+            "fmul v17.4s, v27.4s, v0.s[1]\n"
+            "fmla v5.4s, v20.4s, v17.4s\n"
+            "ldr q17, [x21, #0x10]\n"
+            "uzp1 v20.2d, v9.2d, v18.2d\n"
+            "uzp2 v9.2d, v9.2d, v18.2d\n"
+            "fmul v18.4s, v27.4s, v0.s[2]\n"
+            "fmul v0.4s, v27.4s, v0.s[3]\n"
+            "scvtf v20.4s, v20.4s, #0x4\n"
+            "scvtf v9.4s, v9.4s, #0x4\n"
+            "fmla v7.4s, v20.4s, v18.4s\n"
+            "movi v20.4s, #0x0\n"
+            "movi v18.4s, #0x0\n"
+            ".inst 0x4e88a674  // smmla v20.4s, v19.16b, v8.16b\n"
+            ".inst 0x4e9fa672  // smmla v18.4s, v19.16b, v31.16b\n"
+            "ldr q19, [x21, #0x20]\n"
+            "fmla v4.4s, v9.4s, v0.4s\n"
+            "movi v9.4s, #0x0\n"
+            "movi v0.4s, #0x0\n"
+            ".inst 0x4e88a629  // smmla v9.4s, v17.16b, v8.16b\n"
+            "fmul v8.4s, v27.4s, v26.s[0]\n"
+            ".inst 0x4e9fa620  // smmla v0.4s, v17.16b, v31.16b\n"
+            "ldr q17, [x21, #0x30]\n"
+            ".inst 0x4e8fa674  // smmla v20.4s, v19.16b, v15.16b\n"
+            "fmul v31.4s, v27.4s, v26.s[1]\n"
+            ".inst 0x4e81a672  // smmla v18.4s, v19.16b, v1.16b\n"
+            "ldr q19, [x21, #0x40]\n"
+            ".inst 0x4e8fa629  // smmla v9.4s, v17.16b, v15.16b\n"
+            "fmul v15.4s, v27.4s, v26.s[2]\n"
+            "fmul v27.4s, v27.4s, v26.s[3]\n"
+            ".inst 0x4e81a620  // smmla v0.4s, v17.16b, v1.16b\n"
+            "ldr q1, [x21, #0x50]\n"
+            ".inst 0x4e95a674  // smmla v20.4s, v19.16b, v21.16b\n"
+            ".inst 0x4e90a672  // smmla v18.4s, v19.16b, v16.16b\n"
+            "ldr q26, [x21, #0x60]\n"
+            ".inst 0x4e95a429  // smmla v9.4s, v1.16b, v21.16b\n"
+            ".inst 0x4e90a420  // smmla v0.4s, v1.16b, v16.16b\n"
+            "ldr q21, [x21, #0x70]\n"
+            "add x21, x21, #0x88\n"
+            ".inst 0x4e9da754  // smmla v20.4s, v26.16b, v29.16b\n"
+            ".inst 0x4e83a752  // smmla v18.4s, v26.16b, v3.16b\n"
+            ".inst 0x4e9da6a9  // smmla v9.4s, v21.16b, v29.16b\n"
+            ".inst 0x4e83a6a0  // smmla v0.4s, v21.16b, v3.16b\n"
+            "uzp1 v29.2d, v20.2d, v18.2d\n"
+            "uzp2 v21.2d, v20.2d, v18.2d\n"
+            "scvtf v29.4s, v29.4s, #0x4\n"
+            "uzp1 v18.2d, v9.2d, v0.2d\n"
+            "uzp2 v16.2d, v9.2d, v0.2d\n"
+            "scvtf v21.4s, v21.4s, #0x4\n"
+            "fmla v6.4s, v29.4s, v8.4s\n"
+            "scvtf v18.4s, v18.4s, #0x4\n"
+            "scvtf v16.4s, v16.4s, #0x4\n"
+            "fmla v30.4s, v21.4s, v31.4s\n"
+            "fmla v24.4s, v18.4s, v15.4s\n"
+            "fmla v14.4s, v16.4s, v27.4s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x27, x27, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "str q2, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q10, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q12, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q28, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q11, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q13, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q22, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q23, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q25, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q5, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q7, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q4, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q6, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q30, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q24, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "str q14, [x20, #0x0]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x10, x10, #0x10\n"
+            "cmp x10, #0x10\n"
+            "mov %x[res_ptr], x26\n"
+            "madd %x[a_ptr], x20, x9, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x10, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x24, %x[b_ptr], #0x8\n"
+            "mov x23, %x[nc]\n"
+            "add x22, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "movi v2.16b, #0x0\n"
+            "movi v10.16b, #0x0\n"
+            "add x25, %x[a_ptr], #0x8\n"
+            "mov x21, %x[nb]\n"
+            "movi v12.16b, #0x0\n"
+            "movi v28.16b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ldr q6, [x24, #0x0]\n"
+            "ldr q5, [x24, #0x10]\n"
+            "movi v17.16b, #0x4\n"
+            "movi v8.4s, #0x0\n"
+            "ldr q4, [x25, #0x0]\n"
+            "ldr q13, [x25, #0x10]\n"
+            "movi v27.4s, #0x0\n"
+            "movi v0.4s, #0x0\n"
+            "ldr q31, [x24, #0x20]\n"
+            "ldr q14, [x24, #0x30]\n"
+            "movi v29.4s, #0x0\n"
+            "movi v22.16b, #0xf0\n"
+            "ldr q11, [x25, #0x20]\n"
+            "ldr q23, [x25, #0x30]\n"
+            "sshl v21.16b, v6.16b, v17.16b\n"
+            "sshl v16.16b, v5.16b, v17.16b\n"
+            "ldr q20, [x25, #0x40]\n"
+            "ldr q26, [x25, #0x50]\n"
+            "and v6.16b, v6.16b, v22.16b\n"
+            "and v5.16b, v5.16b, v22.16b\n"
+            "ldr q25, [x25, #0x60]\n"
+            "ldr q3, [x25, #0x70]\n"
+            "sshl v19.16b, v31.16b, v17.16b\n"
+            "sshl v18.16b, v14.16b, v17.16b\n"
+            "ldr d17, [x25, #-0x8]\n"
+            ".inst 0x4e95a488  // smmla v8.4s, v4.16b, v21.16b\n"
+            ".inst 0x4e90a49b  // smmla v27.4s, v4.16b, v16.16b\n"
+            "and v31.16b, v31.16b, v22.16b\n"
+            ".inst 0x4e95a5a0  // smmla v0.4s, v13.16b, v21.16b\n"
+            ".inst 0x4e90a5bd  // smmla v29.4s, v13.16b, v16.16b\n"
+            "and v14.16b, v14.16b, v22.16b\n"
+            "sub x20, x24, #0x8\n"
+            "ldr d16, [x20, #0x0]\n"
+            "subs x21, x21, #0x1\n"
+            "add x25, x25, #0x88\n"
+            "fcvtl v17.4s, v17.4h\n"
+            "add x24, x24, #0x48\n"
+            ".inst 0x4e93a568  // smmla v8.4s, v11.16b, v19.16b\n"
+            ".inst 0x4e92a57b  // smmla v27.4s, v11.16b, v18.16b\n"
+            ".inst 0x4e93a6e0  // smmla v0.4s, v23.16b, v19.16b\n"
+            ".inst 0x4e92a6fd  // smmla v29.4s, v23.16b, v18.16b\n"
+            "fcvtl v16.4s, v16.4h\n"
+            ".inst 0x4e86a688  // smmla v8.4s, v20.16b, v6.16b\n"
+            ".inst 0x4e85a69b  // smmla v27.4s, v20.16b, v5.16b\n"
+            "fmul v23.4s, v16.4s, v17.s[0]\n"
+            "fmul v21.4s, v16.4s, v17.s[1]\n"
+            "fmul v1.4s, v16.4s, v17.s[2]\n"
+            "fmul v20.4s, v16.4s, v17.s[3]\n"
+            ".inst 0x4e86a740  // smmla v0.4s, v26.16b, v6.16b\n"
+            ".inst 0x4e85a75d  // smmla v29.4s, v26.16b, v5.16b\n"
+            ".inst 0x4e9fa728  // smmla v8.4s, v25.16b, v31.16b\n"
+            ".inst 0x4e8ea73b  // smmla v27.4s, v25.16b, v14.16b\n"
+            ".inst 0x4e9fa460  // smmla v0.4s, v3.16b, v31.16b\n"
+            ".inst 0x4e8ea47d  // smmla v29.4s, v3.16b, v14.16b\n"
+            "uzp1 v19.2d, v8.2d, v27.2d\n"
+            "uzp2 v18.2d, v8.2d, v27.2d\n"
+            "scvtf v19.4s, v19.4s, #0x4\n"
+            "uzp1 v17.2d, v0.2d, v29.2d\n"
+            "uzp2 v16.2d, v0.2d, v29.2d\n"
+            "scvtf v18.4s, v18.4s, #0x4\n"
+            "fmla v2.4s, v19.4s, v23.4s\n"
+            "scvtf v17.4s, v17.4s, #0x4\n"
+            "scvtf v16.4s, v16.4s, #0x4\n"
+            "fmla v10.4s, v18.4s, v21.4s\n"
+            "fmla v12.4s, v17.4s, v1.4s\n"
+            "fmla v28.4s, v16.4s, v20.4s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x10, #0x1\n"
+            "str q2, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x2\n"
+            "str q10, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x10, #0x3\n"
+            "str q12, [x20, #0x0]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "str q28, [x20, #0x0]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x23, x23, #0x4\n"
+            "add %x[res_ptr], %x[res_ptr], #0x10\n"
+            "bne 6b\n"
+            "subs x10, x10, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x9\n"
+            "mov %x[res_ptr], x22\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+        );
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
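+    // Same portable reference fallback as in the 4x4 kernel above, taken when the
+    // NEON+i8mm assembly path is unavailable.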
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                        (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
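+// 8x8 interleaved variant: the guarded blocks below provide an SVE+i8mm assembly kernel
+// (taken only when the runtime SVE vector width matches QK8_0) and an AVX2/AVX512
+// intrinsics path for x86.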
+static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
+#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+    if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
+        const void * b_ptr = vx;
+        const void * a_ptr = vy;
+        float * res_ptr = s;
+        size_t res_stride = bs * sizeof(float);
+
+        __asm__ __volatile__(
+            "mov x20, #0x4\n"
+            "mov x13, %x[nr]\n"
+            "mov z28.s, #-0x4\n"
+            "mov x12, #0x88\n"
+            "ptrue p1.b\n"
+            "whilelt p0.s, XZR, x20\n"
+            "cmp x13, #0x10\n"
+            "mul x12, %x[nb], x12\n"
+            "blt 4f\n"
+            "1:"  // Row loop
+            "add x11, %x[b_ptr], #0x10\n"
+            "mov x10, %x[nc]\n"
+            "add x9, %x[res_ptr], %x[res_stride], LSL #4\n"
+            "2:"  // Column loop
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "mov x27, %x[nb]\n"
+            "add x26, x28, x12\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "add x25, x26, x12\n"
+            "mov z13.b, #0x0\n"
+            "mov z1.b, #0x0\n"
+            "add x24, x25, x12\n"
+            "mov z20.b, #0x0\n"
+            "mov z25.b, #0x0\n"
+            "mov z11.b, #0x0\n"
+            "mov z16.b, #0x0\n"
+            "mov z19.b, #0x0\n"
+            "mov z26.b, #0x0\n"
+            "mov z8.b, #0x0\n"
+            "mov z29.b, #0x0\n"
+            "mov z27.b, #0x0\n"
+            "mov z10.b, #0x0\n"
+            "3:"  // Block loop
+            "ld1b { z30.b }, p1/Z, [x11]\n"
+            "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n"
+            "mov z18.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            "ld1rqb { z3.b }, p1/Z, [x28]\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #16]\n"
+            "mov z9.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n"
+            "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n"
+            "sub x20, x11, #0x10\n"
+            "sub x23, x28, #0x8\n"
+            "lsl z31.b, z30.b, #0x4\n"
+            "lsl z6.b, z21.b, #0x4\n"
+            "ld1h { z23.s }, p1/Z, [x20]\n"
+            "sub x22, x26, #0x8\n"
+            "and z30.b, z30.b, #0xf0\n"
+            "and z21.b, z21.b, #0xf0\n"
+            "sub x21, x25, #0x8\n"
+            "sub x20, x24, #0x8\n"
+            "lsl z14.b, z4.b, #0x4\n"
+            "lsl z2.b, z17.b, #0x4\n"
+            "subs x27, x27, #0x1\n"
+            "add x11, x11, #0x90\n"
+            ".inst 0x451f9872  // smmla z18.s, z3.b, z31.b\n"
+            ".inst 0x45069867  // smmla z7.s, z3.b, z6.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #32]\n"
+            "and z4.b, z4.b, #0xf0\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #48]\n"
+            "and z17.b, z17.b, #0xf0\n"
+            "fcvt z23.s, p1/m, z23.h\n"
+            ".inst 0x450e9872  // smmla z18.s, z3.b, z14.b\n"
+            ".inst 0x45029867  // smmla z7.s, z3.b, z2.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #64]\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #80]\n"
+            "fscale z23.s, p1/m, z23.s, z28.s\n"
+            ".inst 0x451e9872  // smmla z18.s, z3.b, z30.b\n"
+            ".inst 0x45159867  // smmla z7.s, z3.b, z21.b\n"
+            "ld1rqb { z3.b }, p1/Z, [x28, #96]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x28, #112]\n"
+            "add x28, x28, #0x88\n"
+            ".inst 0x45049872  // smmla z18.s, z3.b, z4.b\n"
+            ".inst 0x45119867  // smmla z7.s, z3.b, z17.b\n"
+            "ld1h { z3.s }, p0/Z, [x23]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            "uzp1 z5.d, z18.d, z7.d\n"
+            "uzp2 z18.d, z18.d, z7.d\n"
+            "mov z3.q, z3.q[0]\n"
+            "uzp1 z7.d, z9.d, z22.d\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z3.s[0]\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z24.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z5.b }, p1/Z, [x26]\n"
+            "fmul z9.s, z23.s, z3.s[1]\n"
+            "fmla z15.s, p1/M, z18.s, z9.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #16]\n"
+            "fmul z9.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "fmla z12.s, p1/M, z7.s, z9.s\n"
+            "mov z9.s, #0x0\n"
+            "ld1h { z7.s }, p0/Z, [x22]\n"
+            ".inst 0x451f98a9  // smmla z9.s, z5.b, z31.b\n"
+            "fmla z0.s, p1/M, z22.s, z3.s\n"
+            "mov z22.s, #0x0\n"
+            "ld1h { z3.s }, p0/Z, [x21]\n"
+            ".inst 0x450698b6  // smmla z22.s, z5.b, z6.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #32]\n"
+            "fcvt z7.s, p1/m, z7.h\n"
+            "fcvt z3.s, p1/m, z3.h\n"
+            ".inst 0x450e98a9  // smmla z9.s, z5.b, z14.b\n"
+            ".inst 0x450298b6  // smmla z22.s, z5.b, z2.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #64]\n"
+            "mov z7.q, z7.q[0]\n"
+            "mov z3.q, z3.q[0]\n"
+            ".inst 0x451e98a9  // smmla z9.s, z5.b, z30.b\n"
+            ".inst 0x451598b6  // smmla z22.s, z5.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x26, #96]\n"
+            ".inst 0x450498a9  // smmla z9.s, z5.b, z4.b\n"
+            ".inst 0x451198b6  // smmla z22.s, z5.b, z17.b\n"
+            "uzp1 z5.d, z9.d, z22.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "uzp2 z22.d, z9.d, z22.d\n"
+            "fmul z9.s, z23.s, z7.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z13.s, p1/M, z5.s, z9.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x25]\n"
+            "fmul z5.s, z23.s, z7.s[1]\n"
+            "fmla z1.s, p1/M, z22.s, z5.s\n"
+            "mov z5.s, #0x0\n"
+            "mov z22.s, #0x0\n"
+            ".inst 0x451f9a45  // smmla z5.s, z18.b, z31.b\n"
+            ".inst 0x45069a56  // smmla z22.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #48]\n"
+            ".inst 0x450e9a45  // smmla z5.s, z18.b, z14.b\n"
+            ".inst 0x45029a56  // smmla z22.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #80]\n"
+            ".inst 0x451e9a45  // smmla z5.s, z18.b, z30.b\n"
+            ".inst 0x45159a56  // smmla z22.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x26, #112]\n"
+            "add x26, x26, #0x88\n"
+            ".inst 0x45049a45  // smmla z5.s, z18.b, z4.b\n"
+            ".inst 0x45119a56  // smmla z22.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z5.d, z22.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z22.d, z5.d, z22.d\n"
+            "fmul z5.s, z23.s, z7.s[2]\n"
+            "fmul z7.s, z23.s, z7.s[3]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z20.s, p1/M, z18.s, z5.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #16]\n"
+            "ld1h { z5.s }, p0/Z, [x20]\n"
+            "fcvt z5.s, p1/m, z5.h\n"
+            "fmla z25.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9936  // smmla z22.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #32]\n"
+            "mov z5.q, z5.q[0]\n"
+            ".inst 0x450e9936  // smmla z22.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #64]\n"
+            ".inst 0x451e9936  // smmla z22.s, z9.b, z30.b\n"
+            ".inst 0x45159927  // smmla z7.s, z9.b, z21.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x25, #96]\n"
+            ".inst 0x45049936  // smmla z22.s, z9.b, z4.b\n"
+            ".inst 0x45119927  // smmla z7.s, z9.b, z17.b\n"
+            "uzp1 z9.d, z22.d, z7.d\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "uzp2 z22.d, z22.d, z7.d\n"
+            "fmul z7.s, z23.s, z3.s[0]\n"
+            "scvtf z22.s, p1/m, z22.s\n"
+            "fmla z11.s, p1/M, z9.s, z7.s\n"
+            "ld1rqb { z9.b }, p1/Z, [x24]\n"
+            "fmul z7.s, z23.s, z3.s[1]\n"
+            "fmla z16.s, p1/M, z22.s, z7.s\n"
+            "mov z22.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9a56  // smmla z22.s, z18.b, z31.b\n"
+            ".inst 0x45069a47  // smmla z7.s, z18.b, z6.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #48]\n"
+            ".inst 0x450e9a56  // smmla z22.s, z18.b, z14.b\n"
+            ".inst 0x45029a47  // smmla z7.s, z18.b, z2.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #80]\n"
+            ".inst 0x451e9a56  // smmla z22.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x25, #112]\n"
+            "add x25, x25, #0x88\n"
+            ".inst 0x45049a56  // smmla z22.s, z18.b, z4.b\n"
+            ".inst 0x45119a47  // smmla z7.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z22.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp2 z7.d, z22.d, z7.d\n"
+            "fmul z22.s, z23.s, z3.s[2]\n"
+            "fmul z3.s, z23.s, z3.s[3]\n"
+            "scvtf z7.s, p1/m, z7.s\n"
+            "fmla z19.s, p1/M, z18.s, z22.s\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #16]\n"
+            "fmul z22.s, z23.s, z5.s[0]\n"
+            "fmla z26.s, p1/M, z7.s, z3.s\n"
+            "mov z3.s, #0x0\n"
+            "mov z7.s, #0x0\n"
+            ".inst 0x451f9923  // smmla z3.s, z9.b, z31.b\n"
+            ".inst 0x45069927  // smmla z7.s, z9.b, z6.b\n"
+            "ld1rqb { z9.b }, p1/Z, [x24, #32]\n"
+            ".inst 0x450e9923  // smmla z3.s, z9.b, z14.b\n"
+            ".inst 0x45029927  // smmla z7.s, z9.b, z2.b\n"
+            "mov z9.s, #0x0\n"
+            ".inst 0x451f9a49  // smmla z9.s, z18.b, z31.b\n"
+            "mov z31.s, #0x0\n"
+            ".inst 0x45069a5f  // smmla z31.s, z18.b, z6.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #48]\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #64]\n"
+            ".inst 0x450e98c9  // smmla z9.s, z6.b, z14.b\n"
+            "fmul z14.s, z23.s, z5.s[1]\n"
+            ".inst 0x450298df  // smmla z31.s, z6.b, z2.b\n"
+            "ld1rqb { z6.b }, p1/Z, [x24, #80]\n"
+            "fmul z2.s, z23.s, z5.s[2]\n"
+            "fmul z23.s, z23.s, z5.s[3]\n"
+            ".inst 0x451e9a43  // smmla z3.s, z18.b, z30.b\n"
+            ".inst 0x45159a47  // smmla z7.s, z18.b, z21.b\n"
+            "ld1rqb { z5.b }, p1/Z, [x24, #96]\n"
+            ".inst 0x451e98c9  // smmla z9.s, z6.b, z30.b\n"
+            ".inst 0x451598df  // smmla z31.s, z6.b, z21.b\n"
+            "ld1rqb { z18.b }, p1/Z, [x24, #112]\n"
+            "add x24, x24, #0x88\n"
+            ".inst 0x450498a3  // smmla z3.s, z5.b, z4.b\n"
+            ".inst 0x451198a7  // smmla z7.s, z5.b, z17.b\n"
+            ".inst 0x45049a49  // smmla z9.s, z18.b, z4.b\n"
+            ".inst 0x45119a5f  // smmla z31.s, z18.b, z17.b\n"
+            "uzp1 z18.d, z3.d, z7.d\n"
+            "uzp2 z5.d, z3.d, z7.d\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "uzp1 z6.d, z9.d, z31.d\n"
+            "uzp2 z9.d, z9.d, z31.d\n"
+            "scvtf z5.s, p1/m, z5.s\n"
+            "fmla z8.s, p1/M, z18.s, z22.s\n"
+            "scvtf z6.s, p1/m, z6.s\n"
+            "scvtf z9.s, p1/m, z9.s\n"
+            "fmla z29.s, p1/M, z5.s, z14.s\n"
+            "fmla z27.s, p1/M, z6.s, z2.s\n"
+            "fmla z10.s, p1/M, z9.s, z23.s\n"
+            "bgt 3b\n"
+            "mov x20, %x[res_ptr]\n"
+            "subs x10, x10, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z13.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z1.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z20.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z25.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z11.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z16.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z19.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z26.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z8.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z29.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z27.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "st1w { z10.s }, p1, [x20]\n"
+            "bne 2b\n"
+            "mov x20, #0x4\n"
+            "sub x13, x13, #0x10\n"
+            "cmp x13, #0x10\n"
+            "mov %x[res_ptr], x9\n"
+            "madd %x[a_ptr], x20, x12, %x[a_ptr]\n"
+            "bge 1b\n"
+            "4:"  // Row loop skip
+            "cbz x13, 9f\n"
+            "5:"  // Row tail: Row loop
+            "add x25, %x[b_ptr], #0x10\n"
+            "mov x24, %x[nc]\n"
+            "add x23, %x[res_ptr], %x[res_stride], LSL #2\n"
+            "6:"  // Row tail: Column loop
+            "mov z24.b, #0x0\n"
+            "mov z15.b, #0x0\n"
+            "add x28, %x[a_ptr], #0x8\n"
+            "mov x22, %x[nb]\n"
+            "mov z12.b, #0x0\n"
+            "mov z0.b, #0x0\n"
+            "7:"  // Row tail: Block loop
+            "ld1b { z3.b }, p1/Z, [x25]\n"
+            "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n"
+            "mov z2.s, #0x0\n"
+            "mov z25.s, #0x0\n"
+            "ld1rqb { z26.b }, p1/Z, [x28]\n"
+            "ld1rqb { z21.b }, p1/Z, [x28, #16]\n"
+            "mov z27.s, #0x0\n"
+            "mov z19.s, #0x0\n"
+            "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n"
+            "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n"
+            "sub x21, x25, #0x10\n"
+            "sub x20, x28, #0x8\n"
+            "lsl z20.b, z3.b, #0x4\n"
+            "lsl z4.b, z6.b, #0x4\n"
+            "ld1rqb { z10.b }, p1/Z, [x28, #32]\n"
+            "ld1rqb { z23.b }, p1/Z, [x28, #48]\n"
+            "and z3.b, z3.b, #0xf0\n"
+            "and z6.b, z6.b, #0xf0\n"
+            "ld1rqb { z11.b }, p1/Z, [x28, #64]\n"
+            "ld1rqb { z7.b }, p1/Z, [x28, #80]\n"
+            "lsl z8.b, z29.b, #0x4\n"
+            "lsl z14.b, z16.b, #0x4\n"
+            "ld1rqb { z18.b }, p1/Z, [x28, #96]\n"
+            "ld1rqb { z30.b }, p1/Z, [x28, #112]\n"
+            ".inst 0x45149b42  // smmla z2.s, z26.b, z20.b\n"
+            ".inst 0x45049b59  // smmla z25.s, z26.b, z4.b\n"
+            "and z29.b, z29.b, #0xf0\n"
+            "ld1h { z17.s }, p1/Z, [x21]\n"
+            ".inst 0x45149abb  // smmla z27.s, z21.b, z20.b\n"
+            ".inst 0x45049ab3  // smmla z19.s, z21.b, z4.b\n"
+            "and z16.b, z16.b, #0xf0\n"
+            "ld1h { z4.s }, p0/Z, [x20]\n"
+            "subs x22, x22, #0x1\n"
+            "add x28, x28, #0x88\n"
+            "fcvt z17.s, p1/m, z17.h\n"
+            "add x25, x25, #0x90\n"
+            ".inst 0x45089942  // smmla z2.s, z10.b, z8.b\n"
+            ".inst 0x450e9959  // smmla z25.s, z10.b, z14.b\n"
+            "fcvt z4.s, p1/m, z4.h\n"
+            ".inst 0x45089afb  // smmla z27.s, z23.b, z8.b\n"
+            ".inst 0x450e9af3  // smmla z19.s, z23.b, z14.b\n"
+            "fscale z17.s, p1/m, z17.s, z28.s\n"
+            "mov z4.q, z4.q[0]\n"
+            ".inst 0x45039962  // smmla z2.s, z11.b, z3.b\n"
+            ".inst 0x45069979  // smmla z25.s, z11.b, z6.b\n"
+            "fmul z23.s, z17.s, z4.s[0]\n"
+            "fmul z9.s, z17.s, z4.s[1]\n"
+            "fmul z21.s, z17.s, z4.s[2]\n"
+            "fmul z4.s, z17.s, z4.s[3]\n"
+            ".inst 0x450398fb  // smmla z27.s, z7.b, z3.b\n"
+            ".inst 0x450698f3  // smmla z19.s, z7.b, z6.b\n"
+            ".inst 0x451d9a42  // smmla z2.s, z18.b, z29.b\n"
+            ".inst 0x45109a59  // smmla z25.s, z18.b, z16.b\n"
+            ".inst 0x451d9bdb  // smmla z27.s, z30.b, z29.b\n"
+            ".inst 0x45109bd3  // smmla z19.s, z30.b, z16.b\n"
+            "uzp1 z31.d, z2.d, z25.d\n"
+            "uzp2 z13.d, z2.d, z25.d\n"
+            "scvtf z31.s, p1/m, z31.s\n"
+            "uzp1 z17.d, z27.d, z19.d\n"
+            "uzp2 z18.d, z27.d, z19.d\n"
+            "scvtf z13.s, p1/m, z13.s\n"
+            "fmla z24.s, p1/M, z31.s, z23.s\n"
+            "scvtf z17.s, p1/m, z17.s\n"
+            "scvtf z18.s, p1/m, z18.s\n"
+            "fmla z15.s, p1/M, z13.s, z9.s\n"
+            "fmla z12.s, p1/M, z17.s, z21.s\n"
+            "fmla z0.s, p1/M, z18.s, z4.s\n"
+            "bgt 7b\n"
+            "mov x20, %x[res_ptr]\n"
+            "cmp x13, #0x1\n"
+            "st1w { z24.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x2\n"
+            "st1w { z15.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "cmp x13, #0x3\n"
+            "st1w { z12.s }, p1, [x20]\n"
+            "add x20, x20, %x[res_stride]\n"
+            "ble 8f\n"
+            "st1w { z0.s }, p1, [x20]\n"
+            "8:"  // Row tail: Accumulator store skip
+            "subs x24, x24, #0x8\n"
+            "add %x[res_ptr], %x[res_ptr], #0x20\n"
+            "bne 6b\n"
+            "subs x13, x13, #0x4\n"
+            "add %x[a_ptr], %x[a_ptr], x12\n"
+            "mov %x[res_ptr], x23\n"
+            "bgt 5b\n"
+            "9:"  // Row tail: Row loop skip
+            : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr)
+            : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc)
+            : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
+        );
+        return;
+    }
+#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
+#elif defined(__AVX2__) || defined(__AVX512F__)
+    {
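+        // x86 path: unpack the interleaved 4-bit weights through a sign-extension lookup
+        // table, rearrange the eight-column blocks with blends/permutes, and accumulate the
+        // per-row results in wide float accumulators (see acc_rows below).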
+        const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx;
+        const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy;
+        int64_t b_nb = n / QK4_0;
+        int64_t y = 0;
+        // Mask used to extract the 4-bit nibbles from the packed bytes
+        const __m256i m4b = _mm256_set1_epi8(0x0F);
+        const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3);
+        // Lookup table to convert signed nibbles to signed bytes
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+        // Permute mask used for easier vector processing at later stages
+        __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4);
+        int64_t xstart = 0;
+        int anr = nr - nr%16; // Used to align nr with boundary of 16
+    #ifdef __AVX512F__
+        int anc = nc - nc%16; // Used to align nc with boundary of 16
+        // The same nibble-extraction mask, expanded to 512-bit length
+        const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
+        // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
+        __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
+
+        // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
+        for (; y < anr / 4; y += 4) {
+
+            const block_q8_0x4 * a_ptrs[4];
+
+            a_ptrs[0] = a_ptr_start + (y * nb);
+            for (int i = 0; i < 3; ++i) {
+                a_ptrs[i + 1] = a_ptrs[i] + nb;
+            }
+
+            // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation
+            for (int64_t x = 0; x < anc / 8; x += 2) {
+
+                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
+                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+                // Master FP accumulators
+                __m512 acc_rows[16];
+                for (int i = 0; i < 16; i++) {
+                    acc_rows[i] = _mm512_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+
+                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+
+                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+
+                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+
+                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+
+                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+
+                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+
+                    // Scale values - Load the weight scale values of two block_q4_0x8
+                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                    // Process the LHS in groups of four rows (one block_q8_0x4 per iteration)
+                    for (int rp = 0; rp < 4; rp++) {
+
+                        // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                        // Loaded as a set of 128-bit vectors, repeated into a 256-bit vector and then again into a 512-bit vector
+                        __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                        __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+                        __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+                        __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                        __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+                        __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+                        __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                        __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+                        __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+                        __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                        __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+                        __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+                        __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+                        __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+                        __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+                        __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+                        __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+                        __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+                        __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+                        __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+                        // Shuffle pattern one - left side input
+
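+                        // Controls 160 (binary 10 10 00 00) and 245 (binary 11 11 01 01) replicate elements as [e0, e0, e2, e2] and [e1, e1, e3, e3] per 128-bit lane, so each per-lane dot product pairs one LHS row group with one RHS row group, forming 2x2 tiles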
+                        const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                        const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                        const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                        const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                        const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                        const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                        const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                        const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                        // Shuffle pattern two - left side input
+
+                        const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                        const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                        const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                        const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                        const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                        const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                        const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                        const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                        // The shuffled values are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into 32-bit integers
+                        // Resembles the MMLA instructions operating on 2x2 matrices in the ARM version
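+                        // mul_sum_i8_pairs_int32x16 (defined earlier in this file) multiplies corresponding int8 values and horizontally adds each group of four products into one int32 per 32-bit lane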
+                        __m512i iacc_mat_00_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+                        __m512i iacc_mat_01_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                        __m512i iacc_mat_10_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+                        __m512i iacc_mat_11_sp1 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                        __m512i iacc_mat_00_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+                        __m512i iacc_mat_01_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+                        __m512i iacc_mat_10_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+                        __m512i iacc_mat_11_sp2 =
+                            _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+                        // Outputs of both shuffle patterns are added to sum the dot products over all 32 values in the block
+                        __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                        __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                        __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                        __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                        // Straighten out to make 4 row vectors
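+                        // Blend mask 0xCCCC selects alternating pairs of 32-bit lanes from the second operand; shuffle control 78 (binary 01 00 11 10) swaps the two 64-bit halves of each 128-bit lane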
+                        __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                        __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                        __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
+
+                        // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
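+                        // loadMask (defined earlier) restricts the load to the four f16 scale values; shuffle control 68 (binary 01 00 01 00) duplicates that 64-bit group into both halves of the register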
+                        const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68);
+                        const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+                        // Multiply with appropriate scales and accumulate
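+                        // _mm512_shuffle_ps with controls 0, 85, 170 and 255 broadcasts the scale of LHS row 0, 1, 2 and 3 respectively across all lanes before combining with the column scales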
+                        acc_rows[rp * 4]     = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
+                    }
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 16; i++) {
+                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
+        // Take one block_q8_0x4 structure at each pass of the loop and perform the dot product operation
+        for (; y < nr / 4; y ++) {
+
+            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+            // Take a group of two block_q4_0x8 structures at each pass of the loop and perform the dot product operation
+            for (int64_t x = 0; x < anc / 8; x += 2) {
+
+                const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x)     * b_nb);
+                const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb);
+
+                // Master FP accumulators
+                __m512 acc_rows[4];
+                for (int i = 0; i < 4; i++) {
+                    acc_rows[i] = _mm512_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96));
+
+                    const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs));
+                    const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32));
+                    const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64));
+                    const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
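+                    // requiredOrder (defined earlier) swaps the two 128-bit halves, and blend mask 240 (binary 11110000) takes the upper four 32-bit elements from the permuted operand, yielding the B0B1B4B5 / B2B3B6B7 layouts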
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240);
+                    const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240);
+
+                    const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1);
+                    const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1);
+                    const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1);
+
+                    // 4-bit -> 8-bit - Sign is maintained
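+                    // Each nibble, isolated with m4bexpanded, indexes signextendlutexpanded via _mm512_shuffle_epi8 to produce the corresponding signed byte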
+                    const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7)
+                    const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7)
+
+                    const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15)
+                    const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15)
+
+                    const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23)
+                    const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23)
+
+                    const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31)
+                    const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3)
+                    const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3)
+
+                    const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11)
+                    const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11)
+
+                    const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19)
+                    const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19)
+
+                    const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27)
+                    const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7)
+                    const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7)
+
+                    const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15)
+                    const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15)
+
+                    const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23)
+                    const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23)
+
+                    const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31)
+                    const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31)
+
+
+                    // Scale values - Load the weight scale values of two block_q4_0x8
+                    const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+
+                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as a set of 128-bit vectors, repeated into a 256-bit vector and then again into a 512-bit vector
+                    __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+                    __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0);
+                    __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17);
+                    __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+                    __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0);
+                    __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17);
+                    __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+                    __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0);
+                    __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17);
+                    __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+                    __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0);
+                    __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17);
+
+                    __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1);
+                    __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1);
+                    __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1);
+                    __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1);
+                    __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1);
+                    __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1);
+                    __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1);
+                    __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1);
+
+                    // Shuffle pattern one - left side input
+
+                    const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                    const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                    const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                    const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                    const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                    const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                    const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                    const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                    // Shuffle pattern two - left side input
+
+                    const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                    const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                    const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                    const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                    const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                    const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                    const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                    const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                    // The shuffled values are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into 32-bit integers
+                    // Resembles the MMLA instructions operating on 2x2 matrices in the ARM version
+                    __m512i iacc_mat_00_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
+                    __m512i iacc_mat_01_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                    __m512i iacc_mat_10_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
+                    __m512i iacc_mat_11_sp1 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
+                    __m512i iacc_mat_00_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
+                    __m512i iacc_mat_01_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
+                    __m512i iacc_mat_10_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
+                    __m512i iacc_mat_11_sp2 =
+                        _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
+
+                    // Outputs of both shuffle patterns are added to sum the dot products over all 32 values in the block
+                    __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                    __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                    __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                    __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                    // Straighten out to make 4 row vectors
+                    __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01);
+                    __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78));
+                    __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11);
+
+                    // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
+                    const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68);
+                    const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16);
+
+                    // Multiply with appropriate scales and accumulate
+                    acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)),   acc_rows[0]);
+                    acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)),  acc_rows[1]);
+                    acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 4; i++) {
+                    _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
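+        // Column blocks beyond anc were not covered by the paired-block loops above; restart the row loop so the path below handles them from xstart onwards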
+        if (anc != nc) {
+            xstart = anc/8;
+            y = 0;
+        }
+    #endif // __AVX512F__
+
+        // Take a group of four block_q8_0x4 structures at each pass of the loop and perform the dot product operation
+
+        for (; y < anr / 4; y += 4) {
+            const block_q8_0x4 * a_ptrs[4];
+
+            a_ptrs[0] = a_ptr_start + (y * nb);
+            for (int i = 0; i < 3; ++i) {
+                a_ptrs[i + 1] = a_ptrs[i] + nb;
+            }
+
+            // Take one block_q4_0x8 structure (eight interleaved block_q4_0) at each pass of the loop and perform the dot product operation
+            for (int64_t x = xstart; x < nc / 8; x++) {
+
+                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+                // Master FP accumulators
+                __m256 acc_rows[16];
+                for (int i = 0; i < 16; i++) {
+                    acc_rows[i] = _mm256_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+                    // Scale values - Load the weight scale values of block_q4_0x8
+                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+                    // Process LHS in groups of four
+                    for (int rp = 0; rp < 4; rp++) {
+                        // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                        // Loaded as a set of 128-bit vectors and repeated into a 256-bit vector
+                        __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                        __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                        __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                        __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                        __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                        __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                        __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                        __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                        __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                        __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                        __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                        __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+                        // Shuffle pattern one - left side input
+                        const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                        const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                        const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                        const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                        const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                        const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                        const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                        const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                        // Shuffle pattern two - left side input
+                        const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                        const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                        const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                        const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                        const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                        const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                        const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                        const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                        // The shuffled values are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and the products accumulated into 32-bit integers
+                        // Resembles the MMLA instructions operating on 2x2 matrices in the ARM version
+                        __m256i iacc_mat_00_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+                        __m256i iacc_mat_01_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+                        __m256i iacc_mat_10_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+                        __m256i iacc_mat_11_sp1 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+                        __m256i iacc_mat_00_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+                        __m256i iacc_mat_01_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+                        __m256i iacc_mat_10_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+                        __m256i iacc_mat_11_sp2 =
+                            _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+
+                        // Outputs of both shuffle patterns are added to sum the dot products over all 32 values in the block
+                        __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                        __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                        __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                        __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+                        // Straighten out to make 4 row vectors
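+                        // Blend mask 204 (binary 11001100) is the 256-bit counterpart of the 0xCCCC mask used in the AVX512 path, selecting alternating pairs of 32-bit lanes from the second operand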
+                        __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                        __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                        __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                        __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+                        // Load the scale values for all 4 Q8_0 blocks and repeat them across lanes
+                        const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask);
+
+                        // Multiply with appropriate scales and accumulate
+                        acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]);
+                        acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]);
+                        acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
+                        acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32,  255)), acc_rows[rp * 4 + 3]);
+                    }
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 16; i++) {
+                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
+
+        // Take one block_q8_0x4 structure at each pass of the loop and perform the dot product operation
+        for (; y < nr / 4; y ++) {
+
+            const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb);
+
+            // Take one block_q4_0x8 structure (eight interleaved block_q4_0) at each pass of the loop and perform the dot product operation
+            for (int64_t x = xstart; x < nc / 8; x++) {
+
+                const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb);
+
+                // Master FP accumulators
+                __m256 acc_rows[4];
+                for (int i = 0; i < 4; i++) {
+                    acc_rows[i] = _mm256_setzero_ps();
+                }
+
+                for (int64_t b = 0; b < nb; b++) {
+                    // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                    const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs));
+                    const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32));
+                    const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64));
+                    const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96));
+
+                    // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values
+                    const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240);
+                    const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240);
+                    const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240);
+
+                    // 4-bit -> 8-bit - Sign is maintained
+                    const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b));  //B0(0-7) B1(0-7) B4(0-7) B5(0-7)
+                    const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b));  //B2(0-7) B3(0-7) B6(0-7) B7(0-7)
+
+                    const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b));  //B0(8-15) B1(8-15) B4(8-15) B5(8-15)
+                    const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b));  //B2(8-15) B3(8-15) B6(8-15) B7(8-15)
+
+                    const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b));  //B0(16-23) B1(16-23) B4(16-23) B5(16-23)
+                    const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b));  //B2(16-23) B3(16-23) B6(16-23) B7(16-23)
+
+                    const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b));  //B0(24-31) B1(24-31) B4(24-31) B5(24-31)
+                    const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b));  //B2(24-31) B3(24-31) B6(24-31) B7(24-31)
+
+                    // Shuffle pattern one - right side input
+                    const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136);  //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3)
+                    const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136);  //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3)
+
+                    const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136);  //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11)
+                    const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136);  //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11)
+
+                    const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136);  //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19)
+                    const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136);  //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19)
+
+                    const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136);  //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27)
+                    const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136);  //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27)
+
+                    // Shuffle pattern two - right side input
+
+                    const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221);  //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7)
+                    const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221);  //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7)
+
+                    const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221);  //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15)
+                    const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221);  //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15)
+
+                    const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221);  //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23)
+                    const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221);  //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23)
+
+                    const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221);  //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31)
+                    const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221);  //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31)
+
+                    // Scale values - Load the weight scale values of block_q4_0x8
+                    const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+
+                    // Load the four block_q8_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
+                    // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
+                    __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+                    __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                    __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                    __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+                    __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                    __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                    __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+                    __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                    __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                    __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+                    __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                    __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
+
+                    // Shuffle pattern one - left side input
+
+                    const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160);  //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
+                    const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160);  //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3)
+
+                    const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160);  //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11)
+                    const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160);  //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11)
+
+                    const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160);  //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19)
+                    const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160);  //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19)
+
+                    const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160);  //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27)
+                    const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160);  //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27)
+
+                    // Shuffle pattern two - left side input
+
+                    const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245);  //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7)
+                    const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245);  //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7)
+
+                    const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245);  //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15)
+                    const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245);  //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15)
+
+                    const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245);  //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23)
+                    const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245);  //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23)
+
+                    const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245);  //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31)
+                    const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245);  //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31)
+
+                    // The values arranged in the shuffle patterns are combined with a dot product within each 32-bit lane: corresponding bytes are multiplied and accumulated into 32-bit integers per lane
+                    // Resembles the MMLA accumulation into 2x2 matrices in the ARM version
+                    __m256i iacc_mat_00_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1));
+                    __m256i iacc_mat_01_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1));
+                    __m256i iacc_mat_10_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1));
+                    __m256i iacc_mat_11_sp1 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1));
+                    __m256i iacc_mat_00_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2));
+                    __m256i iacc_mat_01_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2));
+                    __m256i iacc_mat_10_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2));
+                    __m256i iacc_mat_11_sp2 =
+                        _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2));
+
+                    // Outputs of both shuffle patterns are added to sum the dot products over all 32 values in the block
+                    __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
+                    __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2);
+                    __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2);
+                    __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2);
+
+
+                    // Straighten out to make 4 row vectors
+                    __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204);
+                    __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204);
+                    __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204);
+                    __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204);
+
+                    // Load the scale values for the 4 Q8_0 blocks and repeat them across lanes
+                    const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask);
+
+                    // Multiply with appropriate scales and accumulate
+                    acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]);
+                    acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]);
+                    acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]);
+                    acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]);
+                }
+
+                // Store the accumulated values
+                for (int i = 0; i < 4; i++) {
+                    _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]);
+                }
+            }
+        }
+        return;
+    }
+#elif defined(__riscv_v_intrinsic)
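+    // RISC-V Vector path: taken only when the vector register length is at least QK4_0 (32) bytes;
+    // each iteration processes 4 activation rows against 8 interleaved weight columns.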
+    if (__riscv_vlenb() >= QK4_0) {
+        const size_t vl = QK4_0;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+                vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
+                for (int l = 0; l < nb; l++) {
+                    const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
+                    const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
+                    const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
+                    const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0);
+                    const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1);
+                    const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0);
+                    const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1);
+
+                    // vector version needs Zvfhmin extension
+                    const float a_scales[4] = {
+                        GGML_FP16_TO_FP32(a_ptr[l].d[0]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[1]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[2]),
+                        GGML_FP16_TO_FP32(a_ptr[l].d[3])
+                    };
+                    const float b_scales[8] = {
+                        GGML_FP16_TO_FP32(b_ptr[l].d[0]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[1]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[2]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[3]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[4]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[5]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[6]),
+                        GGML_FP16_TO_FP32(b_ptr[l].d[7])
+                    };
+                    const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
+
+                    const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0];
+                    const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32];
+                    const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64];
+                    const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l0;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l0 = sumi_hi_m;
+                    }
+
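+                    // Horizontal reduction: repeatedly split the 16-bit partial sums into even/odd halves
+                    // and add them, narrowing down to one 32-bit sum per output column, then scale by the
+                    // activation and weight deltas and accumulate into sumf0.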
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4);
+                        sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8];
+                    const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40];
+                    const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72];
+                    const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l1;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l1 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4);
+                        sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16];
+                    const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48];
+                    const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80];
+                    const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l2;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l2 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4);
+                        sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4);
+                    }
+
+                    const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24];
+                    const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56];
+                    const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88];
+                    const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120];
+                    __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment
+                    vint16m4_t sumi_l3;
+                    {
+                        const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4));
+                        const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4));
+                        const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4));
+                        const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4));
+                        const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2);
+                        const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2);
+                        const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2);
+                        const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2);
+
+                        sumi_l3 = sumi_hi_m;
+                    }
+
+                    {
+                        const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3));
+                        const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
+                        const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
+                        const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
+                        const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
+                        const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
+                        const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
+                        const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
+                        const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
+                        const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
+                        const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
+                        const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
+                        const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
+
+                        const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4);
+                        sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4);
+                    }
+                }
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
+                __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
+            }
+        }
+
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
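+    // Scalar fallback: reference implementation used when none of the SIMD paths above apply.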
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
+                                const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
+                            }
+                            sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
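+// GEMM for IQ4_NL weights repacked 4 columns at a time (block_iq4_nlx4) against 4-row q8_0 activation
+// tiles; the non-linear 4-bit codes are expanded through the kvalues_iq4nl lookup table.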
+static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+                float32x4_t sumf[4];
+                for (int m = 0; m < 4; m++) {
+                    sumf[m] = vdupq_n_f32(0);
+                }
+
+                for (int l = 0; l < nb; l++) {
+                    float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                    int32x4_t sumi_0 = vdupq_n_s32(0);
+                    int32x4_t sumi_1 = vdupq_n_s32(0);
+                    int32x4_t sumi_2 = vdupq_n_s32(0);
+                    int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                    for (int k = 0; k < 4; k++) {
+                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                        int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    }
+
+                    sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                    sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                    sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                    sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+                }
+
+                for (int m = 0; m < 4; m++) {
+                    vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                }
+            }
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                            (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
+    block_q4_0x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_0 * 2 / blck_size_interleave;
+
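+    // In both branches the quant nibbles get 0x8 XOR-ed in: this flips the top bit of every nibble,
+    // converting the biased q4_0 encoding (stored nibble = value + 8) into a two's-complement signed
+    // nibble that the compute kernels can sign-extend directly.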
+    if (blck_size_interleave == 8) {
+        const uint64_t xor_mask = 0x8888888888888888ULL;
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            uint64_t elems;
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+            elems ^= xor_mask;
+            memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        const uint32_t xor_mask = 0x88888888;
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            uint32_t elems;
+            memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
+            elems ^= xor_mask;
+            memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
+// returns an interleaved block_q4_0x8
+// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
+    block_q4_0x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_0 * 4 / blck_size_interleave;
+    const uint64_t xor_mask = 0x8888888888888888ULL;
+
+    for (int i = 0; i < end; ++i) {
+        int src_id = i % 8;
+        int src_offset = (i / 8) * blck_size_interleave;
+        int dst_offset = i * blck_size_interleave;
+
+        uint64_t elems;
+        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
+        elems ^= xor_mask;
+        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+    }
+
+    return out;
+}
+
+static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    constexpr int nrows_interleaved = 4;
+
+    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
+    const block_q4_0 * src = (const block_q4_0 *)data;
+    block_q4_0 dst_tmp[4];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(interleave_block == 8);
+    constexpr int nrows_interleaved = 8;
+
+    block_q4_0x8 * dst = (block_q4_0x8*)t->data;
+    const block_q4_0 * src = (const block_q4_0*) data;
+    block_q4_0 dst_tmp[8];
+    int nrow = ggml_nrows(t);
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 2 / blck_size_interleave;
+
+    // TODO: this branch seems wrong
+    //if (blck_size_interleave == 8) {
+    //    for (int i = 0; i < end; ++i) {
+    //        int src_id = i % 4;
+    //        int src_offset = (i / 4) * blck_size_interleave;
+    //        int dst_offset = i * blck_size_interleave;
+
+    //        // Using memcpy to avoid unaligned memory accesses
+    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+    //    }
+    //} else
+    if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    //GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    GGML_ASSERT(interleave_block == 4);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+namespace ggml::cpu::aarch64 {
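+// Thin template wrappers keyed on (block type, interleave size, number of interleaved columns);
+// tensor_traits<BLOC_TYPE, INTER_SIZE, NB_COLS> below uses them to dispatch to the matching
+// repack / gemv / gemm kernels at compile time.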
+// repack
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct ggml_tensor *, const void *, size_t);
+
+// TODO: generalise.
+template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
+}
+
+// TODO: needs to be revisited
+//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+//}
+
+// gemv
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemv(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+// gemm
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemm(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+  public:
+    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        // not really GGML_TYPE_Q8_0, but same size.
+        switch (op->op) {
+        case GGML_OP_MUL_MAT:
+            size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
+            return true;
+        case GGML_OP_MUL_MAT_ID:
+            size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1]));
+            size = GGML_PAD(size, sizeof(int64_t));  // + padding for next block.
+            size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+            return true;
+        default:
+            // GGML_ABORT("fatal error");
+            break;
+        }
+        return false;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
+        switch (op->op) {
+        case GGML_OP_MUL_MAT:
+            forward_mul_mat(params, op);
+            return true;
+        case GGML_OP_MUL_MAT_ID:
+            forward_mul_mat_id(params, op);
+            return true;
+        default:
+            // GGML_ABORT("fatal error");
+            break;
+        }
+        return false;
+    }
+
+    void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        ggml_tensor *       dst  = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        GGML_ASSERT(ne0 == ne01);
+        GGML_ASSERT(ne1 == ne11);
+        GGML_ASSERT(ne2 == ne12);
+        GGML_ASSERT(ne3 == ne13);
+
+        // dst cannot be transposed or permuted
+        GGML_ASSERT(nb0 == sizeof(float));
+        GGML_ASSERT(nb0 <= nb1);
+        GGML_ASSERT(nb1 <= nb2);
+        GGML_ASSERT(nb2 <= nb3);
+
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
+        // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
+
+        char *       wdata = static_cast<char *>(params->wdata);
+        const size_t nbw1  = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+
+        assert(params->wsize >= nbw1 * ne11);
+
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+
+        int64_t i11_processed = 0;
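+        // Quantize src1 to q8_0 in the scratch buffer: rows are packed four at a time into interleaved
+        // block_q8_0x4 tiles; any leftover rows (ne11 % 4) are quantized one by one below.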
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
+                              INTER_SIZE);
+        }
+        i11_processed = ne11 - ne11 % 4;
+        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const void * src1_wdata      = params->wdata;
+        const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        int64_t      src0_start      = (ith * ne01) / nth;
+        int64_t      src0_end        = ((ith + 1) * ne01) / nth;
+        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+        src0_end   = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
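+        // Round each thread's slice of src0 rows up to a multiple of NB_COLS so every thread works on
+        // whole interleaved column groups.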
+        if (src0_start >= src0_end) {
+            return;
+        }
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
+                                                 (const char *) src0->data + src0_start * nb01,
+                                                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                                                 (const char *) src0->data + src0_start * nb01,
+                                                 (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                                                 src0_end - src0_start);
+        }
+    }
+
+    void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
+        const ggml_tensor * src0 = op->src[0];
+        const ggml_tensor * src1 = op->src[1];
+        const ggml_tensor * ids  = op->src[2];
+        ggml_tensor *       dst  = op;
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float;
+
+        // we don't support permuted src0 or src1
+        GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+        GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+        // dst cannot be transposed or permuted
+        GGML_ASSERT(nb0 == sizeof(float));
+        GGML_ASSERT(nb0 <= nb1);
+        GGML_ASSERT(nb1 <= nb2);
+        GGML_ASSERT(nb2 <= nb3);
+
+        GGML_ASSERT(ne03 == 1);
+        GGML_ASSERT(ne13 == 1);
+        GGML_ASSERT(ne3  == 1);
+
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        // row groups
+        const int n_ids = ids->ne[0]; // n_expert_used
+        const int n_as  = ne02;       // n_expert
+
+        const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        struct mmid_row_mapping {
+            int32_t i1;
+            int32_t i2;
+        };
+
+        GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
+                                      n_as * ne12 * sizeof(mmid_row_mapping)));
+
+        auto                      wdata             = (char *) params->wdata;
+        auto                      wdata_src1_end    = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t));
+        int64_t *                 matrix_row_counts = (int64_t *) (wdata_src1_end);                      // [n_as]
+        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as);  // [n_as][ne12]
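+        // Scratch layout: quantized src1 first (padded to int64 alignment), then one row count per
+        // expert, then the row mappings used to gather src1 rows per expert.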
+
+        // src1: float32 => block_q8_0
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
+                           (void *)               (wdata + i12 * nbw2 + i11 * nbw1),
+                           ne10);
+            }
+        }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
+
+        if (ith == 0) {
+            // initialize matrix_row_counts
+            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
+
+            // group rows by src0 matrix
+            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+                for (int32_t id = 0; id < n_ids; ++id) {
+                    const int32_t i02 =
+                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
+                    matrix_row_counts[i02] += 1;
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        // compute each matrix multiplication in sequence
+        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+            const int64_t cne1 = matrix_row_counts[cur_a];
+
+            if (cne1 == 0) {
+                continue;
+            }
+
+            auto src0_cur = (const char *) src0->data + cur_a*nb02;
+
+            //const int64_t nr0 = ne01; // src0 rows
+            const int64_t nr1 = cne1; // src1 rows
+
+            int64_t src0_cur_start = (ith * ne01) / nth;
+            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
+            src0_cur_start =
+                (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+            if (src0_cur_start >= src0_cur_end) return;
+
+            for (int ir1 = 0; ir1 < nr1; ir1++) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+                const int id       = row_mapping.i1; // selected expert index
+
+                const int64_t  i11 = id % ne11;
+                const int64_t  i12 = row_mapping.i2; // row index in src1
+
+                const int64_t  i1 = id;  // selected expert index
+                const int64_t  i2 = i12; // row
+
+                auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
+                        ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
+                        ne01,                    src0_cur + src0_cur_start * nb01,
+                        src1_col, 1, src0_cur_end - src0_cur_start);
+            }
+        }
+#undef MMID_MATRIX_ROW
+    }
+
+    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
+        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
+                       (int) NB_COLS, (int) INTER_SIZE);
+        return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+    }
+};
+
+// instance for Q4
+static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
+static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
+static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+
+// instance for IQ4
+static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+
+}  // namespace ggml::cpu::aarch64
+
+static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) {
+    if (cur->type == GGML_TYPE_Q4_0) {
+        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
+            if (cur->ne[1] % 8 == 0) {
+                return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
+            }
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
+            }
+        }
+    }
+
+    return nullptr;
+}
+
+static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_aarch64_get_optimal_repack_type(tensor));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                       const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
+    auto OK            = tensor_traits->repack(tensor, data, size);
+
+    GGML_ASSERT(OK == 0);
+    GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == nullptr) {
+        return nullptr;
+    }
+
+    buffer->buft              = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
+    buffer->iface.set_tensor  = ggml_backend_cpu_aarch64_buffer_set_tensor;
+    return buffer;
+}
+
+static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::aarch64 {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        if (    op->op == GGML_OP_MUL_MAT &&
+                op->src[0]->buffer &&
+                (ggml_n_dims(op->src[0]) == 2) &&
+                op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() &&
+                ggml_aarch64_get_optimal_repack_type(op->src[0])
+                ) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
+            //    return true;
+            //}
+            // may be possible if Q8_0 packed...
+        } else if (op->op == GGML_OP_MUL_MAT_ID
+                && op->src[0]->buffer
+                && (ggml_n_dims(op->src[0]) == 3)
+                && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()
+                && ggml_aarch64_get_optimal_repack_type(op->src[0])
+                ) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_F32) {
+                return true;
+            }
+            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
+            //    return true;
+            //}
+        }
+        return false;
+    }
+
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
+            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) {
+                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+            }
+        }
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::aarch64
+
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = {
+        /* .iface    = */ {
+                           /* .get_name         = */ ggml_backend_cpu_aarch64_buffer_type_get_name,
+                           /* .alloc_buffer     = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer,
+                           /* .get_alignment    = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment,
+                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
+                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes
+                           /* .is_host          = */ nullptr,
+                           },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(),
+    };
+
+    return &ggml_backend_cpu_buffer_type_aarch64;
+}
diff --git a/llama/ggml-cpu-aarch64.h b/llama/ggml-cpu-aarch64.h
new file mode 100644
index 000000000..14320735c
--- /dev/null
+++ b/llama/ggml-cpu-aarch64.h
@@ -0,0 +1,34 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml-cpu-traits.h"
+#include "ggml.h"
+
+// GGML internal header
+
+ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void);
diff --git a/llama/ggml-cpu-impl.h b/llama/ggml-cpu-impl.h
new file mode 100644
index 000000000..54dc108cc
--- /dev/null
+++ b/llama/ggml-cpu-impl.h
@@ -0,0 +1,412 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+// GGML CPU internal header
+
+#include "ggml.h"
+#include "ggml-impl.h"
+#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
+//#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_compute_params {
+    // ith = thread index, nth = number of threads
+    int ith, nth;
+
+    // work buffer for all threads
+    size_t wsize;
+    void * wdata;
+
+    struct ggml_threadpool * threadpool;
+};
+
+
+#if defined(_MSC_VER)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#include <sys/prctl.h>
+#endif
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#ifdef _MSC_VER
+
+typedef uint16_t ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
+
+#else
+
+typedef __fp16 ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
+
+#endif // _MSC_VER
+
+#if !defined(__aarch64__)
+
+// 32-bit ARM compatibility
+
+// vaddlvq_s16
+// vpaddq_s16
+// vpaddq_s32
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
+
+inline static int32_t vaddlvq_s16(int16x8_t v) {
+    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
+    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+    return vcombine_s32(a0, b0);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
+}
+
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+// NOTE: not tested
+inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+// NOTE: not tested
+inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
+    uint8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+#define ggml_vqtbl1q_s8   vqtbl1q_s8
+#define ggml_vqtbl1q_u8   vqtbl1q_u8
+
+#endif // !defined(__aarch64__)
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+#endif // !defined(__ARM_FEATURE_DOTPROD)
+
+#endif // defined(__ARM_NEON)
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#if defined(__loongarch64)
+#if defined(__loongarch_asx)
+#include <lasxintrin.h>
+#endif
+#if defined(__loongarch_sx)
+#include <lsxintrin.h>
+#endif
+#endif
+
+#if defined(__loongarch_asx)
+
+typedef union {
+    int32_t i;
+    float f;
+} ft_union;
+
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+}
+
+static __m256 __lasx_xvreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+}
+#endif
+
+// TODO: move to ggml-threading
+void ggml_barrier(struct ggml_threadpool * tp);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml-cpu-quants.c b/llama/ggml-cpu-quants.c
new file mode 100644
index 000000000..a8288deca
--- /dev/null
+++ b/llama/ggml-cpu-quants.c
@@ -0,0 +1,10865 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-cpu-quants.h"
+#include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#define GROUP_MAX_EPS 1e-15f
+#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+#define GROUP_MAX_EPS_IQ2_S 1e-8f
+#define GROUP_MAX_EPS_IQ1_M 1e-7f
+#define GROUP_MAX_EPS_IQ1_S 1e-12f
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid warnings for hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// some compilers don't provide _mm256_set_m128i, e.g. gcc 7
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = _mm_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = _mm_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = _mm_maddubs_epi16(ax, sy);
+    const __m128i ones = _mm_set1_epi16(1);
+    return _mm_madd_epi16(ones, dot);
+}
+
+#if __AVX__ || __AVX2__ || __AVX512F__
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    const __m128i hi64 = _mm_unpackhi_epi64(a, a);
+    const __m128i sum64 = _mm_add_epi32(hi64, a);
+    const __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
+#if defined(__AVX2__) || defined(__AVX512F__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = _mm256_set_epi64x(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+    __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask);
+    const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytes = _mm256_or_si256(bytes, bit_mask);
+    return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8( 0xF );
+    return _mm256_and_si256(lowMask, bytes);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#elif defined(__AVXVNNI__)
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+#endif
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
+static inline __m128i packNibbles( __m256i bytes )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+#if __AVX512F__
+    const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
+    bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
+    return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
+#else
+    const __m256i lowByte = _mm256_set1_epi16( 0xFF );
+    __m256i high = _mm256_andnot_si256( lowByte, bytes );
+    __m256i low = _mm256_and_si256( lowByte, bytes );
+    high = _mm256_srli_epi16( high, 4 );
+    bytes = _mm256_or_si256( low, high );
+
+    // Compress uint16_t lanes into bytes
+    __m128i r0 = _mm256_castsi256_si128( bytes );
+    __m128i r1 = _mm256_extracti128_si256( bytes, 1 );
+    return _mm_packus_epi16( r0, r1 );
+#endif
+}
+#elif defined(__AVX__)
+static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
+{
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m128i lowByte = _mm_set1_epi16( 0xFF );
+    __m128i high = _mm_andnot_si128( lowByte, bytes1 );
+    __m128i low = _mm_and_si128( lowByte, bytes1 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes1 = _mm_or_si128( low, high );
+    high = _mm_andnot_si128( lowByte, bytes2 );
+    low = _mm_and_si128( lowByte, bytes2 );
+    high = _mm_srli_epi16( high, 4 );
+    bytes2 = _mm_or_si128( low, high );
+
+    return _mm_packus_epi16( bytes1, bytes2);
+}
+
+static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
+    const __m128i ax = _mm_sign_epi8(x, x);
+    const __m128i sy = _mm_sign_epi8(y, x);
+    return _mm_maddubs_epi16(ax, sy);
+}
+
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+    const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+    __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+    __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
+    const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+    bytesl = _mm_or_si128(bytesl, bit_mask);
+    bytesh = _mm_or_si128(bytesh, bit_mask);
+    bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+    bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+    return MM256_SET_M128I(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+    // Load 16 bytes from memory
+    __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+    __m128i tmph = _mm_srli_epi16(tmpl, 4);
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    tmpl = _mm_and_si128(lowMask, tmpl);
+    tmph = _mm_and_si128(lowMask, tmph);
+    return MM256_SET_M128I(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
+    const __m128i ones = _mm_set1_epi16(1);
+    const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+    const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+    const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    const __m128i xl = _mm256_castsi256_si128(x);
+    const __m128i xh = _mm256_extractf128_si256(x, 1);
+    const __m128i yl = _mm256_castsi256_si128(y);
+    const __m128i yh = _mm256_extractf128_si256(y, 1);
+    // Get absolute values of x vectors
+    const __m128i axl = _mm_sign_epi8(xl, xl);
+    const __m128i axh = _mm_sign_epi8(xh, xh);
+    // Sign the values of the y vectors
+    const __m128i syl = _mm_sign_epi8(yl, xl);
+    const __m128i syh = _mm_sign_epi8(yh, xh);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
+// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
+static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
+                                           const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
+    const __m128i mone = _mm_set1_epi16(1);
+
+    const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
+    const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
+    const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
+    const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
+    const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
+    const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
+    const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
+    const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
+    const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1);
+    const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1);
+    return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1));
+}
+
+// quad fp16 delta calculation
+static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
+    // GGML_FP16_TO_FP32 is faster than Intel F16C
+    return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0)));
+}
+#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =_mm_hadd_ps(a, b);
+    __m128 res_1 =_mm_hadd_ps(c, d);
+    __m128 res =_mm_hadd_ps(res_0, res_1);
+    res =_mm_hadd_ps(res, res);
+    res =_mm_hadd_ps(res, res);
+
+    return _mm_cvtss_f32(res);
+}
+#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
+
+#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
+#define B1(c,s,n)  0x ## n ## c ,  0x ## n ## s
+#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s)
+#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s)
+#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s)
+#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s)
+#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s)
+#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
+#define B8(c,s  ) B7(c,s,     c), B7(c,s,     s)
+
+// precomputed tables for expanding 8bits to 8 bytes:
+static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
+static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
+#endif
+
+#if defined(__loongarch_asx)
+
+#ifdef __clang__
+#define VREGS_PREFIX "$vr"
+#define XREGS_PREFIX "$xr"
+#else // GCC
+#define VREGS_PREFIX "$f"
+#define XREGS_PREFIX "$f"
+#endif
+#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
+// Convert __m128i to __m256i
+static inline __m256i ____m256i(__m128i in) {
+    __m256i out = __lasx_xvldi(0);
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " XREGS_PREFIX"\\i    \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[in], " VREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        : [out] "+f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+// Convert two __m128i to __m256i
+static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
+    __m256i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[hi], " VREGS_PREFIX "\\i    \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[lo], " VREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x20  \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".ifnc %[out], %[hi]                 \n\t"
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " XREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[hi], " VREGS_PREFIX "\\j  \n\t"
+        "    xvori.b $xr\\i, $xr\\j, 0       \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".endif                              \n\t"
+        : [out] "=f" (out), [hi] "+f" (inhi)
+        : [lo] "f" (inlo)
+    );
+    return out;
+}
+// Convert __m256i low part to __m128i
+static inline __m128i lasx_extracti128_lo(__m256i in) {
+    __m128i out;
+    __asm__ volatile (
+        ".ifnc %[out], %[in]                 \n\t"
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " VREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[in], " XREGS_PREFIX "\\j  \n\t"
+        "    vori.b $vr\\i, $vr\\j, 0        \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        ".endif                              \n\t"
+        : [out] "=f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+// Convert __m256i high part to __m128i
+static inline __m128i lasx_extracti128_hi(__m256i in) {
+    __m128i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS                "\n\t"
+        " .ifc %[out], " VREGS_PREFIX "\\i   \n\t"
+        "  .irp j," __ALL_REGS              "\n\t"
+        "   .ifc %[in], " XREGS_PREFIX "\\j  \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x11  \n\t"
+        "   .endif                           \n\t"
+        "  .endr                             \n\t"
+        " .endif                             \n\t"
+        ".endr                               \n\t"
+        : [out] "=f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+
+static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) {
+    v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7};
+    return (__m256i)__ret;
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
+    v4i64 __ret = {d, c, b, a};
+    return (__m256i)__ret;
+}
+
+static __m256i lasx_insertf128( __m128i x, __m128i y) {
+    return lasx_set_q(x, y);
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
+    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
+    __m256i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lasx_xvreplgr2vr_b(f);
+    zero = __lasx_xvldi(0);
+    tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
+    mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones
+    return __lasx_xvshuf_b(a, zero, tmp2);
+}
+
+static __m256i lasx_extu8_16(__m128i a) {
+    __m128i zero = __lsx_vldi(0);
+    __m128i vlo = __lsx_vilvl_b(zero, a);
+    __m128i vhi = __lsx_vilvh_b(zero, a);
+    return lasx_set_q(vhi, vlo);
+}
+
+static __m256i lasx_ext8_16(__m128i a) {
+     __m128i sign = __lsx_vslti_b(a, 0);
+     __m128i vlo = __lsx_vilvl_b(sign, a);
+     __m128i vhi = __lsx_vilvh_b(sign, a);
+     return lasx_set_q(vhi, vlo);
+}
+
+static __m256i lasx_ext16_32(__m128i a) {
+    __m256i tmp1;
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
+    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
+    return tmp1;
+}
+
+static __m128i lasx_extracti128( __m256i a, int pos) {
+    __m128i ret;
+    if( pos == 0)
+    {
+       ret = lasx_extracti128_lo(a);
+    } else {
+       ret = lasx_extracti128_hi(a);
+    }
+    return ret;
+}
+
+static __m128 lasx_extractf128( __m256 a, int pos) {
+    __m128 ret;
+    if( pos == 0)
+    {
+       ret = (__m128)lasx_extracti128_lo((__m256i)a);
+    } else {
+       ret = (__m128)lasx_extracti128_hi((__m256i)a);
+    }
+    return ret;
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_h_b(a, b);
+    tmp2 = __lasx_xvmulwod_h_b(a, b);
+    return __lasx_xvsadd_h(tmp1, tmp2);
+}
+
+static __m256i lasx_madd_h(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_w_h(a, b);
+    tmp2 = __lasx_xvmulwod_w_h(a, b);
+    return __lasx_xvadd_w(tmp1, tmp2);
+}
+
+static __m256i lasx_packs_w(__m256i a, __m256i b) {
+    __m256i tmp, tmp1;
+    tmp = __lasx_xvsat_w(a, 15);
+    tmp1 = __lasx_xvsat_w(b, 15);
+    return __lasx_xvpickev_h(tmp1, tmp);
+}
+
+static __m256i lasx_packs_h(__m256i a, __m256i b) {
+    __m256i tmp, tmp1;
+    tmp = __lasx_xvsat_h(a, 7);
+    tmp1 = __lasx_xvsat_h(b, 7);
+    return __lasx_xvpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = __lsx_vsigncov_b(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = __lsx_vsigncov_b(x, y);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = lsx_maddubs_h(ax, sy);
+    const __m128i ones = __lsx_vreplgr2vr_h(1);
+    return lsx_madd_h(ones, dot);
+}
+
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = lasx_extractf128(x, 1);
+    ft_union tmp;
+    res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
+    tmp.i = __lsx_vpickve2gr_w(res, 0);
+    return tmp.f;
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+
+    __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11);
+    __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00);
+
+    __m128i  tmp1_128 = lasx_extracti128_lo(tmp1);
+    __m128i  tmp2_128 = lasx_extracti128_lo(tmp2);
+
+    __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128);
+
+    __m128i ev = __lsx_vpickev_w(sum128, sum128);
+    __m128i od = __lsx_vpickod_w(sum128, sum128);
+    __m128i sum64 = __lsx_vadd_w(ev, od);
+
+    int sum64_1, sum64_2;
+    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
+    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
+
+    return  sum64_1 + sum64_2;
+}
+
+// horizontally add 4 int32_t
+static inline int hsum_i32_4(const __m128i a) {
+    __m128i ev = __lsx_vpickev_w(a, a);
+    __m128i od = __lsx_vpickod_w(a, a);
+    __m128i sum64 = __lsx_vadd_w(ev, od);
+
+    int sum64_1, sum64_2;
+    sum64_1 = __lsx_vpickve2gr_w(sum64, 0);
+    sum64_2 = __lsx_vpickve2gr_w(sum64, 1);
+
+    return  sum64_1 + sum64_2;
+}
+
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+
+    uint32_t x32;
+    memcpy(&x32, x, sizeof(uint32_t));
+    const __m256i shuf_mask = lasx_set_d(
+            0x0303030303030303, 0x0202020202020202,
+            0x0101010101010101, 0x0000000000000000);
+
+    __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask);
+    const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe);
+    bytes = __lasx_xvor_v(bytes, bit_mask);
+    return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1));
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
+    const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
+    __m128i hi = __lsx_vsrli_h(lo, 4);
+    return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    __m256i v = __lasx_xvpackod_h(x, x);
+    __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
+    return __lasx_xvffint_s_w(summed_pairs);
+}
+
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = lasx_maddubs_h(ax, sy);
+    return sum_i16_pairs_float(dot);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+
+    // Get absolute values of x vectors
+    const __m256i ax = __lasx_xvsigncov_b(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = __lasx_xvsigncov_b(x, y);
+
+    return mul_sum_us8_pairs_float(ax, sy);
+}
+
+static inline __m128i packNibbles( __m256i bytes ) {
+    // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
+    const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF);
+     __m256i high = __lasx_xvandn_v(lowByte, bytes);
+    __m256i low = __lasx_xvand_v(lowByte, bytes);
+    high = __lasx_xvsrli_h(high, 4);
+    bytes = __lasx_xvor_v(low, high);
+    // Compress uint16_t lanes into bytes
+    __m128i *r0 = (__m128i *)&bytes;
+    __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11);
+    __m128i *r1 = (__m128i *)&tmp_h128;
+
+    __m128i zero = __lsx_vldi(0);
+    __m128i tmp, tmp2, tmp3;
+
+    tmp = __lsx_vmax_h(zero, *r0);
+    tmp2 = __lsx_vsat_hu(tmp, 7);
+
+    tmp = __lsx_vmax_h(zero, *r1);
+    tmp3 = __lsx_vsat_hu(tmp, 7);
+    return  __lsx_vpickev_b(tmp3, tmp2);
+}
+#endif  //__loongarch_asx
+
+void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q4_0_ref(x, y, k);
+}
+
+void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q4_1_ref(x, y, k);
+}
+
+void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q5_0_ref(x, y, k);
+}
+
+void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q5_1_ref(x, y, k);
+}
+
+void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+        }
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+        }
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = maxScalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction is fixing the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since we don't have in AVX some necessary functions,
+        // we split the registers in half and call AVX2 analogs from SSE
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_0);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_0, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    for (int i = 0; i < nb; i++) {
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+        vector signed int vi[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+        const vector float vid = vec_splats(id);
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const vector float v  = vec_round(vec_mul(srcv[j], vid));
+            vi[j] = vec_cts(v, 0);
+        }
+        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);
+        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+    }
+
+#elif defined(__loongarch_asx)
+    for (int i = 0; i < nb; i++) {
+        ft_union fi;
+        __m256 v0 = (__m256)__lasx_xvld( x , 0);
+        __m256 v1 = (__m256)__lasx_xvld( x , 32);
+        __m256 v2 = (__m256)__lasx_xvld( x , 64);
+        __m256 v3 = (__m256)__lasx_xvld( x , 96);
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
+
+        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) );
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
+        __m128 tmp = max4;
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
+        fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
+        const float max_scalar = fi.f;
+
+        // Quantize these floats
+        const float d = max_scalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+        const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id );
+
+        // Apply the multiplier
+        v0 = __lasx_xvfmul_s( v0, mul );
+        v1 = __lasx_xvfmul_s( v1, mul );
+        v2 = __lasx_xvfmul_s( v2, mul );
+        v3 = __lasx_xvfmul_s( v3, mul );
+
+        // Round to nearest integer
+        __m256i i0 = __lasx_xvftintrne_w_s( v0 );
+        __m256i i1 = __lasx_xvftintrne_w_s( v1 );
+        __m256i i2 = __lasx_xvftintrne_w_s( v2 );
+        __m256i i3 = __lasx_xvftintrne_w_s( v3 );
+
+        __m128i ni0 = lasx_extracti128( i0, 0 );
+        __m128i ni1 = lasx_extracti128( i0, 1);
+        __m128i ni2 = lasx_extracti128( i1, 0);
+        __m128i ni3 = lasx_extracti128( i1, 1);
+        __m128i ni4 = lasx_extracti128( i2, 0);
+        __m128i ni5 = lasx_extracti128( i2, 1);
+        __m128i ni6 = lasx_extracti128( i3, 0);
+        __m128i ni7 = lasx_extracti128( i3, 1);
+
+        // Convert int32 to int16
+        ni0 = lsx_packs_w( ni0, ni1 );
+        ni2 = lsx_packs_w( ni2, ni3 );
+        ni4 = lsx_packs_w( ni4, ni5 );
+        ni6 = lsx_packs_w( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = lsx_packs_h( ni0, ni2 );
+        ni4 = lsx_packs_h( ni4, ni6 );
+
+        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
+        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
+
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_0_ref(x, y, k);
+#endif
+}
+
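+// quantize_row_q8_1 mirrors quantize_row_q8_0: each block of QK8_1 floats is
+// stored as an fp16 scale d = amax/127 and int8 quants q = round(x/d), but the
+// block additionally records s = d * sum(q). Dot products against formats that
+// carry a per-block minimum (e.g. q4_1) use this precomputed sum for the min term.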
+void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        float32x4_t srcv [8];
+        float32x4_t asrcv[8];
+        float32x4_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vld1q_f32(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = vmaxvq_f32(amaxv[0]);
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        int32x4_t accv = vdupq_n_s32(0);
+
+        for (int j = 0; j < 8; j++) {
+            const float32x4_t v  = vmulq_n_f32(srcv[j], id);
+            const int32x4_t   vi = vcvtnq_s32_f32(v);
+
+            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+
+            accv = vaddq_s32(accv, vi);
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+    }
+#elif defined(__wasm_simd128__)
+    for (int i = 0; i < nb; i++) {
+        v128_t srcv [8];
+        v128_t asrcv[8];
+        v128_t amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = wasm_v128_load(x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
+                                   wasm_f32x4_extract_lane(amaxv[0], 1)),
+                               MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
+                                   wasm_f32x4_extract_lane(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        v128_t accv = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < 8; j++) {
+            const v128_t v  = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
+            const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);
+
+            y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
+            y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
+            y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
+            y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
+
+            accv = wasm_i32x4_add(accv, vi);
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(
+                d * (wasm_i32x4_extract_lane(accv, 0) +
+                     wasm_i32x4_extract_lane(accv, 1) +
+                     wasm_i32x4_extract_lane(accv, 2) +
+                     wasm_i32x4_extract_lane(accv, 3)));
+    }
+#elif defined(__AVX2__) || defined(__AVX__)
+    for (int i = 0; i < nb; i++) {
+        // Load elements into 4 AVX vectors
+        __m256 v0 = _mm256_loadu_ps( x );
+        __m256 v1 = _mm256_loadu_ps( x + 8 );
+        __m256 v2 = _mm256_loadu_ps( x + 16 );
+        __m256 v3 = _mm256_loadu_ps( x + 24 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 signBit = _mm256_set1_ps( -0.0f );
+        __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+        maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float max_scalar = _mm_cvtss_f32( max4 );
+
+        // Quantize these floats
+        const float d = max_scalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+
+        // Apply the multiplier
+        v0 = _mm256_mul_ps( v0, mul );
+        v1 = _mm256_mul_ps( v1, mul );
+        v2 = _mm256_mul_ps( v2, mul );
+        v3 = _mm256_mul_ps( v3, mul );
+
+        // Round to nearest integer
+        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+        // Convert floats to integers
+        __m256i i0 = _mm256_cvtps_epi32( v0 );
+        __m256i i1 = _mm256_cvtps_epi32( v1 );
+        __m256i i2 = _mm256_cvtps_epi32( v2 );
+        __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+        // Compute the sum of the quants and set y[i].s
+        y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))));
+
+        // Convert int32 to int16
+        i0 = _mm256_packs_epi32( i0, i1 );	// 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
+        i2 = _mm256_packs_epi32( i2, i3 );	// 16, 17, 18, 19,  24, 25, 26, 27,  20, 21, 22, 23, 28, 29, 30, 31
+                                            // Convert int16 to int8
+        i0 = _mm256_packs_epi16( i0, i2 );	// 0, 1, 2, 3,  8, 9, 10, 11,  16, 17, 18, 19,  24, 25, 26, 27,  4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+        // We got our precious signed bytes, but the order is now wrong
+        // These AVX2 pack instructions process 16-byte pieces independently
+        // The following instruction fixes the order
+        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+        i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+        _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+        // Since AVX lacks some of the functions we need, we split the 256-bit
+        // registers in half and use the SSE analogs of the AVX2 instructions
+        __m128i ni0 = _mm256_castsi256_si128( i0 );
+        __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+        __m128i ni2 = _mm256_castsi256_si128( i1 );
+        __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+        __m128i ni4 = _mm256_castsi256_si128( i2 );
+        __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+        __m128i ni6 = _mm256_castsi256_si128( i3 );
+        __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1)));
+
+        // Convert int32 to int16
+        ni0 = _mm_packs_epi32( ni0, ni1 );
+        ni2 = _mm_packs_epi32( ni2, ni3 );
+        ni4 = _mm_packs_epi32( ni4, ni5 );
+        ni6 = _mm_packs_epi32( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = _mm_packs_epi16( ni0, ni2 );
+        ni4 = _mm_packs_epi16( ni4, ni6 );
+
+        _mm_storeu_si128((__m128i *)(y[i].qs +  0), ni0);
+        _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+    }
+#elif defined(__riscv_v_intrinsic)
+
+    size_t vl = __riscv_vsetvl_e32m4(QK8_1);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vfloat32m4_t v_x   = __riscv_vle32_v_f32m4(x+i*QK8_1, vl);
+
+        vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl);
+        vfloat32m1_t tmp   = __riscv_vfmv_v_f_f32m1(0.0, vl);
+        vfloat32m1_t vmax  = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl);
+        float amax = __riscv_vfmv_f_s_f32m1_f32(vmax);
+
+        const float d  = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl);
+
+        // convert to integer
+        vint16m2_t   vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl);
+        vint8m1_t    vs = __riscv_vncvt_x_x_w_i8m1(vi, vl);
+
+        // store result
+        __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
+
+        // compute sum for y[i].s
+        vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl);
+        vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl);
+
+        // set y[i].s
+        int sum = __riscv_vmv_x_s_i16m1_i16(vwrs);
+        y[i].s = GGML_FP32_TO_FP16(sum*d);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    for (int i = 0; i < nb; i++) {
+        vector float srcv [8];
+        vector float asrcv[8];
+        vector float amaxv[8];
+        vector signed int vi[8];
+
+        for (int j = 0; j < 8; j++) srcv[j]  = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+        const vector float vid = vec_splats(id);
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        vector int accv = vec_splats(0);
+
+        for (int j = 0; j < 8; j++) {
+            const vector float v  = vec_round(vec_mul(srcv[j], vid));
+            vi[j] = vec_cts(v, 0);
+
+            accv = vec_add(accv, vi[j]);
+        }
+        vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])),  0, &y[i].qs[0]);
+        vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]);
+
+        accv = vec_add(accv, vec_sld(accv, accv, 4));
+        accv = vec_add(accv, vec_sld(accv, accv, 8));
+        y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0));
+    }
+
+#elif defined(__loongarch_asx)
+    for (int i = 0; i < nb; i++) {
+        ft_union ft;
+        __m256 v0 = (__m256)__lasx_xvld( x , 0 );
+        __m256 v1 = (__m256)__lasx_xvld( x , 32 );
+        __m256 v2 = (__m256)__lasx_xvld( x , 64 );
+        __m256 v3 = (__m256)__lasx_xvld( x , 96 );
+        x += 32;
+
+        // Compute max(abs(e)) for the block
+        const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f );
+        __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) );
+        max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) );
+
+        __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) );
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
+        __m128 tmp = max4;
+        max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
+        ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
+        const float max_scalar = ft.f;
+
+        // Quantize these floats
+        const float d = max_scalar / 127.f;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f;
+        const __m256 mul = __lasx_xvreplfr2vr_s( id );
+
+        // Apply the multiplier
+        v0 = __lasx_xvfmul_s( v0, mul );
+        v1 = __lasx_xvfmul_s( v1, mul );
+        v2 = __lasx_xvfmul_s( v2, mul );
+        v3 = __lasx_xvfmul_s( v3, mul );
+
+        // Round to nearest integer
+        __m256i i0 = __lasx_xvftintrne_w_s( v0 );
+        __m256i i1 = __lasx_xvftintrne_w_s( v1 );
+        __m256i i2 = __lasx_xvftintrne_w_s( v2 );
+        __m256i i3 = __lasx_xvftintrne_w_s( v3 );
+
+        __m128i ni0 = lasx_extracti128(i0, 0);
+        __m128i ni1 = lasx_extracti128( i0, 1);
+        __m128i ni2 = lasx_extracti128( i1, 0);
+        __m128i ni3 = lasx_extracti128( i1, 1);
+        __m128i ni4 = lasx_extracti128( i2, 0 );
+        __m128i ni5 = lasx_extracti128( i2, 1);
+        __m128i ni6 = lasx_extracti128( i3, 0);
+        __m128i ni7 = lasx_extracti128( i3, 1);
+
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3));
+        const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7));
+        y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1)));
+
+        // Convert int32 to int16
+        ni0 = lsx_packs_w( ni0, ni1 );
+        ni2 = lsx_packs_w( ni2, ni3 );
+        ni4 = lsx_packs_w( ni4, ni5 );
+        ni6 = lsx_packs_w( ni6, ni7 );
+        // Convert int16 to int8
+        ni0 = lsx_packs_h( ni0, ni2 );
+        ni4 = lsx_packs_h( ni4, ni6 );
+
+        __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
+        __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
+    }
+#else
+    GGML_UNUSED(nb);
+    // scalar
+    quantize_row_q8_1_ref(x, y, k);
+#endif
+}
+
+//
+// 2-6 bit quantization in super-blocks
+//
+
+//
+// ===================== Helper functions
+//
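+// nearest_int() rounds to the nearest integer via the IEEE-754 binary32 layout:
+// adding 12582912.0f (1.5 * 2^23) places the value in [2^23, 2^24), where the low
+// 23 mantissa bits hold the rounded integer plus a 0x400000 bias, so masking and
+// subtracting the bias recovers round-to-nearest. The assert keeps |fval| small
+// enough for the trick to stay exact.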
+static inline int nearest_int(float fval) {
+    assert(fabsf(fval) <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
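+// make_qx_quants() picks the scale that minimizes the weighted squared error
+// sum(w*(x - scale*l)^2). For a fixed assignment l the optimal scale is
+// sumlx/suml2 and the residual is sum(w*x^2) - sumlx^2/suml2, so candidate grids
+// are compared via `sumlx*sumlx > best*suml2`; when no explicit weights qw are
+// given, rmse_type selects the weighting of each element.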
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
+        const float * restrict qw) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < GROUP_MAX_EPS) { // all zero
+        for (int i = 0; i < n; ++i) {
+            L[i] = 0;
+        }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (rmse_type == 0) {
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+        }
+        return 1/iscale;
+    }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 0; i < n; ++i) {
+#else
+    for (int i = 0; i < n; ++i) {
+#endif
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+        float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    float scale = suml2 ? sumlx/suml2 : 0.0f;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+    float best = scale * sumlx;
+    for (int is = -9; is <= 9; ++is) {
+        if (is == 0) {
+            continue;
+        }
+        iscale = -(nmax + 0.1f*is) / max;
+        sumlx = suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+            for (int i = 0; i < n; ++i) {
+                int l = nearest_int(iscale * x[i]);
+                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+            }
+            scale = sumlx/suml2; best = scale*sumlx;
+        }
+    }
+    return scale;
+}
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < GROUP_MAX_EPS) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (do_rmse) {
+        float sumlx = 0;
+        float suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            L[i] = l;
+            float w = x[i]*x[i];
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        for (int itry = 0; itry < 5; ++itry) {
+            int n_changed = 0;
+            for (int i = 0; i < n; ++i) {
+                float w = x[i]*x[i];
+                float slx = sumlx - w*x[i]*L[i];
+                if (slx > 0) {
+                    float sl2 = suml2 - w*L[i]*L[i];
+                    int new_l = nearest_int(x[i] * sl2 / slx);
+                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
+                    if (new_l != L[i]) {
+                        slx += w*x[i]*new_l;
+                        sl2 += w*new_l*new_l;
+                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+                            L[i] = new_l; sumlx = slx; suml2 = sl2;
+                            ++n_changed;
+                        }
+                    }
+                }
+            }
+            if (!n_changed) {
+                break;
+            }
+        }
+        for (int i = 0; i < n; ++i) {
+            L[i] += nmax;
+        }
+        return sumlx / suml2;
+    }
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+    }
+    return 1/iscale;
+}
+
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = alpha*min + (1 - alpha)*sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
+    for (int i = 1; i < n; ++i) {
+#endif
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
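+// Unpack the j-th 6-bit scale (*d) and 6-bit min (*m) of a K-quant super-block:
+// eight scales and eight mins are packed into the 12 bytes q[0..11]. Entries 0-3
+// sit in the low 6 bits of q[0..3] (scales) and q[4..7] (mins); entries 4-7 are
+// split between the nibbles of q[8..11] and the top 2 bits of q[0..7].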
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+//========================= 2-bit (de)-quantization
+
+void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
+    quantize_row_q2_K_ref(x, vy, k);
+}
+
+//========================= 3-bit (de)-quantization
+
+void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
+    quantize_row_q3_K_ref(x, vy, k);
+}
+
+// ====================== 4-bit (de)-quantization
+
+void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q4_K * restrict y = vy;
+    quantize_row_q4_K_ref(x, y, k);
+}
+
+// ====================== 5-bit (de)-quantization
+
+void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q5_K * restrict y = vy;
+    quantize_row_q5_K_ref(x, y, k);
+}
+
+// ====================== 6-bit (de)-quantization
+
+void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_q6_K * restrict y = vy;
+    quantize_row_q6_K_ref(x, y, k);
+}
+
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
+
+void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq1_0 * restrict y = vy;
+    quantize_row_tq1_0_ref(x, y, k);
+}
+
+void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) {
+    assert(k % QK_K == 0);
+    block_tq2_0 * restrict y = vy;
+    quantize_row_tq2_0_ref(x, y, k);
+}
+
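+// Codebook for the non-linear 4-bit format (IQ4_NL): a 4-bit index selects one of
+// these 16 signed values, which is then multiplied by the block scale.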
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+//===================================== Q8_K ==============================================
+
+void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q8_K_ref(x, y, k);
+}
+
+//===================================== Dot products =================================
+
+//
+// Helper functions
+//
+#if __AVX__ || __AVX2__ || __AVX512F__
+
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return _mm_loadu_si128((const __m128i*)k_shuffle + i);
+}
+#elif defined(__loongarch_asx)
+// shuffles to pick the required scales in dot products
+static inline __m256i get_scale_shuffle_q3k(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,     2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,     6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,    10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,    14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
+    };
+    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
+}
+static inline __m256i get_scale_shuffle_k4(int i) {
+    static const uint8_t k_shuffle[256] = {
+         0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
+         2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
+         4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
+         6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
+         8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
+        10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
+        12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
+        14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
+    };
+    return __lasx_xvld((const __m256i*)k_shuffle + i, 0);
+}
+static inline __m128i get_scale_shuffle(int i) {
+    static const uint8_t k_shuffle[128] = {
+         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+         2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
+         4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
+         6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
+         8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
+        10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
+        12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
+        14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
+    };
+    return __lsx_vld((const __m128i*)k_shuffle + i, 0);
+}
+#endif
+
+void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
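+    // With the int8 matrix-multiply extension (i8mm) and nrc == 2, two rows of x
+    // and two rows of y are processed together: vmmlaq_s32 accumulates a 2x2 tile
+    // of dot products per pair of blocks, which is then scaled by the four
+    // d_x*d_y combinations and written out with row stride bs.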
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_0 * restrict vx0 = vx;
+        const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx);
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_0 * restrict b_x0 = &vx0[i];
+            const block_q4_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+            const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // sub 8
+            const int8x16_t x0_l = vsubq_s8(v0_0l, s8b);
+            const int8x16_t x0_h = vsubq_s8(v0_0h, s8b);
+            const int8x16_t x1_l = vsubq_s8(v0_1l, s8b);
+            const int8x16_t x1_h = vsubq_s8(v0_1h, s8b);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
+            float32x4_t scale = vld1q_f32(_scale);
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s,      vget_low_f32 (sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+
+        return;
+    }
+#endif
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+
+    // VLA implementation: dispatch on the runtime SVE vector length and use a
+    // path specialized for 128-, 256- or 512-bit vectors
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating higher lanes for 4 float32 elements
+                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F));
+                    const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04));
+                    const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F));
+                    const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04));
+
+                    // sub 8
+                    const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8);
+                    const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8);
+                    const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8);
+                    const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8);
+
+                    // load y
+                    const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16);
+                    const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs);
+                    const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx0ls, qy0l),
+                                    svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4,
+                                    svdot_s32(svdup_n_s32(0), qx1ls, qy1l),
+                                    svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for  16 int8 elements
+                const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating higher lanes for 32 int8 elements
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+
+                // predicate for activating higher lanes for 16 int8 elements
+                const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
+                // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes
+                const svbool_t pl16 = svnot_b_z(ph32, ph16);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q4_0 * restrict x0 = &x[ib + 0];
+                    const block_q4_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
+                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
+
+                    // 4-bit -> 8-bit
+                    const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04));
+                    const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04));
+
+                    // sub 8
+                    const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8);
+                    const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
+                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);
+
+                    // dot product
+                    sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32,
+                                svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
+            } break;
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+
+#elif defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib + 0];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib + 0];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // sub 8
+        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // dot product into int32x4_t
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = _mm256_set1_epi8( 8 );
+        qx = _mm256_sub_epi8( qx, off );
+
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
+        const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
+        const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
+        const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
+
+        const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+        const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+        const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+        const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+        const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
+        const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
+        const __m256 p =  sum_i16_pairs_float(p_2, p_1);
+
+        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+#elif defined(__SSSE3__)
+    // set constants
+    const __m128i lowMask = _mm_set1_epi8(0xF);
+    const __m128i off = _mm_set1_epi8(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = _mm_setzero_ps();
+    __m128 acc_1 = _mm_setzero_ps();
+    __m128 acc_2 = _mm_setzero_ps();
+    __m128 acc_3 = _mm_setzero_ps();
+
+    for (; ib + 1 < nb; ib += 2) {
+        _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs);
+
+        __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+        __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+        bx_0 = _mm_sub_epi8(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+        __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
+        bx_1 = _mm_sub_epi8(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+
+        const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+
+        __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+        __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        bx_2 = _mm_sub_epi8(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+        __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16));
+        bx_3 = _mm_sub_epi8(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = _mm_cvtepi32_ps(i32_0);
+        __m128 p1 = _mm_cvtepi32_ps(i32_1);
+        __m128 p2 = _mm_cvtepi32_ps(i32_2);
+        __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+        __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+        __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+        __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = _mm_add_ps(p0_d, acc_0);
+        acc_1 = _mm_add_ps(p1_d, acc_1);
+        acc_2 = _mm_add_ps(p2_d, acc_2);
+        acc_3 = _mm_add_ps(p3_d, acc_3);
+    }
+
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (; ib < nb; ++ib) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        // subtract offset
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector signed char v8 = vec_splats((signed char)0x8);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed char q4x0 = vec_and(qxs, lowMask);
+        vector signed char q4x1 = vec_sr(qxs, v4);
+
+        q4x0 = vec_sub(q4x0, v8);
+        q4x1 = vec_sub(q4x1, v8);
+
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi0 = vec_sum4s(qv1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+
+        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+        const __m256i off = __lasx_xvreplgr2vr_b( 8 );
+        qx = __lasx_xvsub_b( qx, off );
+
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = __lasx_xvfmadd_s( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__loongarch_sx)
+    // set constants
+    const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
+    const __m128i off = __lsx_vreplgr2vr_b(8);
+
+    // Initialize accumulator with zeros
+    __m128 acc_0 = __lsx_vldi(0);
+    __m128 acc_1 = __lsx_vldi(0);
+    __m128 acc_2 = __lsx_vldi(0);
+    __m128 acc_3 = __lsx_vldi(0);
+
+    for (; ib + 1 < nb; ib += 2) {
+
+        // Compute combined scale for the block 0 and 1
+        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+
+        const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
+
+        __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1);
+        __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0);
+        bx_0 = __lsx_vsub_b(bx_0, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+        __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4));
+        __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0);
+        bx_1 = __lsx_vsub_b(bx_1, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+        //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+        //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+        // Compute combined scale for the block 2 and 3
+        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+
+        const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
+
+        __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3);
+        __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0);
+        bx_2 = __lsx_vsub_b(bx_2, off);
+        const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+        __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4));
+        __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0);
+        bx_3 = __lsx_vsub_b(bx_3, off);
+        const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+        // Convert int32_t to float
+        __m128 p0 = __lsx_vffint_s_w(i32_0);
+        __m128 p1 = __lsx_vffint_s_w(i32_1);
+        __m128 p2 = __lsx_vffint_s_w(i32_2);
+        __m128 p3 = __lsx_vffint_s_w(i32_3);
+
+        // Apply the scale
+        __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 );
+        __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 );
+        __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 );
+        __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 );
+
+        // Accumulate
+        acc_0 = __lsx_vfadd_s(p0_d, acc_0);
+        acc_1 = __lsx_vfadd_s(p1_d, acc_1);
+        acc_2 = __lsx_vfadd_s(p2_d, acc_2);
+        acc_3 = __lsx_vfadd_s(p3_d, acc_3);
+    }
+
+    sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#endif
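+    // Scalar tail: handles any block left over by the SIMD paths that step two
+    // blocks at a time, and performs the whole computation when no SIMD path
+    // above was compiled in.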
+    for (; ib < nb; ++ib) {
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[ib].qs[j] & 0x0F) - 8;
+            const int v1 = (x[ib].qs[j] >>   4) - 8;
+
+            sumi0 += (v0 * y[ib].qs[j]);
+            sumi1 += (v1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
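+    // With the int8 matrix-multiply (i8mm) extension and nrc == 2, compute a 2x2 tile of results
+    // using vmmlaq_s32 and write the two output rows at s and s + bs.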
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q4_1 * restrict vx0 = vx;
+        const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx);
+        const block_q8_1 * restrict vy0 = vy;
+        const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by);
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+        float32x4_t summs0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q4_1 * restrict b_x0 = &vx0[i];
+            const block_q4_1 * restrict b_x1 = &vx1[i];
+            const block_q8_1 * restrict b_y0 = &vy0[i];
+            const block_q8_1 * restrict b_y1 = &vy1[i];
+
+            float32_t summs_t[4] = {
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+            };
+            summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));
+
+            const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+            const uint8x16_t v0_0 = vld1q_u8(b_x0->qs);
+            const uint8x16_t v0_1 = vld1q_u8(b_x1->qs);
+
+            // 4-bit -> 8-bit
+            const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+            const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+            const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+            const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            // mmla into int32x4_t
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
+            float32x4_t scale = vld1q_f32(_scale);
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        sumv2 = vaddq_f32(sumv2, summs0);
+
+        vst1_f32(s,      vget_low_f32 (sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+
+        return;
+    }
+#endif
+
+    int ib = 0;
+    float sumf = 0;
+
+    // TODO: add WASM SIMD
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs = 0;
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_1 * restrict x0 = &x[ib + 0];
+        const block_q4_1 * restrict x1 = &x[ib + 1];
+        const block_q8_1 * restrict y0 = &y[ib + 0];
+        const block_q8_1 * restrict y1 = &y[ib + 1];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // dot product into int32x4_t
+        const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
+#elif defined(__AVX2__) || defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+        const float d1 = GGML_FP16_TO_FP32(y[ib].d);
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs );
+
+        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
+
+        // Accumulate d0*d1*x*y
+#if defined(__AVX2__)
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
+        acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (; ib < nb; ++ib) {
+        // load elements
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        // mask and store lower part of x, and then upper part
+        vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+        vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f};
+        vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask);
+        vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4);
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_msum(q8y0, q4x0, vsumi0);
+        vsumi0 = vec_msum(q8y1, q4x1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    float summs = 0;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const float d0 = GGML_FP16_TO_FP32(x[ib].d);
+        const float d1 = GGML_FP16_TO_FP32(y[ib].d);
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        const __m256 d0v = __lasx_xvreplfr2vr_s( d0 );
+        const __m256 d1v = __lasx_xvreplfr2vr_s( d1 );
+
+        // Compute combined scales
+        const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0);
+
+        const __m256 xy = mul_sum_us8_pairs_float(qx, qy);
+
+        // Accumulate d0*d1*x*y
+        acc = __lasx_xvfmadd_s( d0d1, xy, acc );
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#endif
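+    // Scalar fallback for the remaining blocks: q4_1 keeps the raw [0, 15] nibbles and adds the
+    // per-block min contribution m*s on top of the scaled dot product.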
+    for (; ib < nb; ++ib) {
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int v0 = (x[ib].qs[j] & 0x0F);
+            const int v1 = (x[ib].qs[j] >>   4);
+
+            sumi0 += (v0 * y[ib].qs[j]);
+            sumi1 += (v1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    int ib = 0;
+    float sumf = 0;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    uint32_t qh0;
+    uint32_t qh1;
+
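+    // scratch buffers for the byte-expanded high-bit masks looked up from table_b2b_1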
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q5_0 * restrict x0 = &x[ib];
+        const block_q5_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        // extract the 5th bit via lookup table ((!b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_1[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_1[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_1[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_1[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_1[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_1[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (; ib < nb; ++ib) {
+        const block_q5_0 * restrict x0 = &x[ib];
+        const block_q8_0 * restrict y0 = &y[ib];
+
+        const v128_t m4b  = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
+        const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
+        const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
+                        wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0));
+        qx = _mm256_or_si256(qx, bxhi);
+
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_fmadd_ps(d, q, acc);
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8((char)0xF0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_andnot_si128(bxhil, mask);
+        bxhih = _mm_andnot_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx_0);
+        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx_0 = MM256_SET_M128I(bxh, bxl);
+
+        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
+
+        /* Multiply q with scale and accumulate */
+        acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // These temporary registers are for masking and shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
+
+    vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl);
+    vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (; ib < nb; ++ib) {
+        memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+
+        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl);
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+
+        // ((qh & (1u << (j + 16))) >> (j + 12));
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl);
+        vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl);
+        vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector unsigned char v4 = vec_splats((unsigned char)4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])};
+        vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])};
+
+        vector signed char qh0 = (vector signed char)aux64x2_0;
+        vector signed char qh1 = (vector signed char)aux64x2_1;
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+        vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0);
+        vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1);
+
+        vector signed char q8y0 = vec_xl(  0, y[ib].qs);
+        vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+        vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1));
+
+        qv0 = vec_add(qv0, qv1);
+
+        vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0));
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        /* Compute combined scale for the block */
+        const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0));
+        qx = __lasx_xvor_v(qx, bxhi);
+
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        /* Multiply q with scale and accumulate */
+        acc = __lasx_xvfmadd_s(d, q, acc);
+    }
+
+    sumf = hsum_float_8(acc);
+#endif
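+    // Scalar fallback for the remaining blocks: the 5th bit of each value is recovered from qh,
+    // OR-ed onto the nibble, and the result is re-centered by subtracting 16.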
+    for (; ib < nb; ++ib) {
+        uint32_t qh;
+        memcpy(&qh, x[ib].qh, sizeof(qh));
+
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+
+            const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16);
+            const int32_t x1 = (int8_t)(((x[ib].qs[j] >>   4) | xh_1) - 16);
+
+            sumi0 += (x0 * y[ib].qs[j]);
+            sumi1 += (x1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi;
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_1;
+    const int nb = n / qk;
+
+    int ib = 0;
+    float sumf = 0;
+
+    assert(n % qk == 0);
+    assert(qk == QK5_1);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_1 * restrict x = vx;
+    const block_q8_1 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float summs0 = 0.0f;
+    float summs1 = 0.0f;
+
+    uint32_t qh0;
+    uint32_t qh1;
+
+    uint64_t tmp0[4];
+    uint64_t tmp1[4];
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q5_1 * restrict x0 = &x[ib];
+        const block_q5_1 * restrict x1 = &x[ib + 1];
+        const block_q8_1 * restrict y0 = &y[ib];
+        const block_q8_1 * restrict y1 = &y[ib + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+
+        summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
+        summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s);
+
+        // extract the 5th bit via lookup table ((b) << 4)
+        memcpy(&qh0, x0->qh, sizeof(qh0));
+        memcpy(&qh1, x1->qh, sizeof(qh1));
+
+        tmp0[0] = table_b2b_0[(qh0 >>  0) & 0xFF];
+        tmp0[1] = table_b2b_0[(qh0 >>  8) & 0xFF];
+        tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
+        tmp0[3] = table_b2b_0[(qh0 >> 24)       ];
+
+        tmp1[0] = table_b2b_0[(qh1 >>  0) & 0xFF];
+        tmp1[1] = table_b2b_0[(qh1 >>  8) & 0xFF];
+        tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
+        tmp1[3] = table_b2b_0[(qh1 >> 24)       ];
+
+        const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
+        const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
+        const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
+        const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // add high bit
+        const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
+        const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
+        const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
+        const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
+                        ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
+#elif defined(__wasm_simd128__)
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    float summs = 0.0f;
+
+    uint32_t qh;
+    uint64_t tmp[4];
+
+    // TODO: check if unrolling this is better
+    for (; ib < nb; ++ib) {
+        const block_q5_1 * restrict x0 = &x[ib];
+        const block_q8_1 * restrict y0 = &y[ib];
+
+        summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s);
+
+        const v128_t m4b = wasm_i8x16_splat(0x0F);
+
+        // extract the 5th bit
+        memcpy(&qh, x0->qh, sizeof(qh));
+
+        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh >> 24)       ];
+
+        const v128_t qhl = wasm_v128_load(tmp + 0);
+        const v128_t qhh = wasm_v128_load(tmp + 2);
+
+        const v128_t v0 = wasm_v128_load(x0->qs);
+
+        // 4-bit -> 8-bit
+        const v128_t v0l = wasm_v128_and (v0, m4b);
+        const v128_t v0h = wasm_u8x16_shr(v0, 4);
+
+        // add high bit
+        const v128_t v0lf = wasm_v128_or(v0l, qhl);
+        const v128_t v0hf = wasm_v128_or(v0h, qhh);
+
+        // load y
+        const v128_t v1l = wasm_v128_load(y0->qs);
+        const v128_t v1h = wasm_v128_load(y0->qs + 16);
+
+        // int8x16 -> int16x8
+        const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf);
+        const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf);
+        const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf);
+        const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf);
+
+        const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l);
+        const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l);
+        const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
+        const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
+
+        // dot product
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs;
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
+        qx = _mm256_or_si256(qx, bxhi);
+
+        const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+        const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(qx, qy);
+
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+    __m128i mask = _mm_set1_epi8(0x10);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d));
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs);
+        const __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        __m128i bxhil = _mm256_castsi256_si128(bxhi);
+        __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+        bxhil = _mm_and_si128(bxhil, mask);
+        bxhih = _mm_and_si128(bxhih, mask);
+        __m128i bxl = _mm256_castsi256_si128(bx_0);
+        __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
+        bxl = _mm_or_si128(bxl, bxhil);
+        bxh = _mm_or_si128(bxh, bxhih);
+        bx_0 = MM256_SET_M128I(bxh, bxl);
+
+        const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d));
+        const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
+
+        acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    uint32_t qh;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    // temporary registers for shift operations
+    vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
+    vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl);
+
+    for (; ib < nb; ++ib) {
+        memcpy(&qh, x[ib].qh, sizeof(uint32_t));
+
+        // load qh
+        vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl);
+
+        // ((qh >> (j +  0)) << 4) & 0x10;
+        vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl);
+        vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl);
+        vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl);
+
+        // ((qh >> (j + 12))     ) & 0x10;
+        vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl);
+        vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl);
+
+        // narrowing
+        vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl);
+        vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl);
+
+        vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl);
+        vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl);
+
+        // load
+        vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl);
+
+        vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl);
+        vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl);
+
+        vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl);
+        vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl);
+
+        vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl);
+        vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl);
+
+        vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a);
+        vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l);
+
+        vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl);
+        vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m));
+        vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f};
+        vsumf0 = vec_madd(vxmin, vys, vsumf0);
+
+        vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])};
+        vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])};
+
+        vector signed char qh0 = (vector signed char)aux64x2_0;
+        vector signed char qh1 = (vector signed char)aux64x2_1;
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+
+        vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0);
+        vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1);
+
+        vector signed char q8y0 = vec_xl(  0, y[ib].qs);
+        vector signed char q8y1 = vec_xl( 16, y[ib].qs);
+
+        vector signed int vsumi0 = v0;
+
+        vsumi0 = vec_msum(q8y0, q5x0, vsumi0);
+        vsumi0 = vec_msum(q8y1, q5x1, vsumi0);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    float summs = 0.0f;
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d));
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        __m256i qx = bytes_from_nibbles_32(x[ib].qs);
+        __m256i bxhi = bytes_from_bits_32(x[ib].qh);
+        bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10));
+        qx = __lasx_xvor_v(qx, bxhi);
+
+        const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d));
+        const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_us8_pairs_float(qx, qy);
+
+        acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc);
+    }
+
+    sumf = hsum_float_8(acc) + summs;
+#endif
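+    // Scalar fallback for the remaining blocks: same high-bit reconstruction as q5_0, but the
+    // values stay unsigned and the per-block min contribution m*s is added.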
+    for (; ib < nb; ++ib) {
+        uint32_t qh;
+        memcpy(&qh, x[ib].qh, sizeof(qh));
+
+        int sumi0 = 0;
+        int sumi1 = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0;
+            const int32_t x1 = (x[ib].qs[j] >>  4) | xh_1;
+
+            sumi0 += (x0 * y[ib].qs[j]);
+            sumi1 += (x1 * y[ib].qs[j + qk/2]);
+        }
+
+        int sumi = sumi0 + sumi1;
+        sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s);
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+
+    assert(n % qk == 0);
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    assert((nrc == 2) || (nrc == 1));
+#else
+    assert(nrc == 1);
+#endif
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q8_0 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
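+    // As in the q4_1 case: with i8mm available and nrc == 2, compute a 2x2 tile of results with
+    // vmmlaq_s32 and store the two output rows at s and s + bs.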
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    if (nrc == 2) {
+        const block_q8_0 * restrict vx0 = vx;
+        const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx);
+        const block_q8_0 * restrict vy0 = vy;
+        const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by);
+
+        float32x4_t sumv0 = vdupq_n_f32(0.0f);
+
+        for (int i = 0; i < nb; i++) {
+            const block_q8_0 * restrict b_x0 = &vx0[i];
+            const block_q8_0 * restrict b_y0 = &vy0[i];
+
+            const block_q8_0 * restrict b_x1 = &vx1[i];
+            const block_q8_0 * restrict b_y1 = &vy1[i];
+
+            const int8x16_t x0_l = vld1q_s8(b_x0->qs);
+            const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16);
+            const int8x16_t x1_l = vld1q_s8(b_x1->qs);
+            const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16);
+
+            // load y
+            const int8x16_t y0_l = vld1q_s8(b_y0->qs);
+            const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16);
+            const int8x16_t y1_l = vld1q_s8(b_y1->qs);
+            const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);
+
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
+            float32x4_t scale = vld1q_f32(_scale);
+
+            int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+            int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
+
+            int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+            int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h)));
+
+            int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+            int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l)));
+
+            int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+            int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
+
+            sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
+        }
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
+        float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
+        vst1_f32(s,      vget_low_f32 (sumv2));
+        vst1_f32(s + bs, vget_high_f32(sumv2));
+
+        return;
+    }
+#endif
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined(__ARM_FEATURE_SVE)
+    svfloat32_t sumv0 = svdup_n_f32(0.0f);
+    svfloat32_t sumv1 = svdup_n_f32(0.0f);
+
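+    // SVE vector length in bits, used to select the matching unrolled path below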
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+
+    // VLA implementation for SVE
+    switch (vector_length) {
+        case 128:
+            {
+                // predicate for activating lanes for 16 Int8 elements
+                const svbool_t ph16 = svptrue_pat_b8 (SV_VL16);
+                const svbool_t pl16 = svptrue_pat_b32(SV_VL4);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0_0 = svld1_s8(ph16, x0->qs);
+                    const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16);
+                    const svint8_t qx1_0 = svld1_s8(ph16, x1->qs);
+                    const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16);
+
+                    // load y
+                    const svint8_t qy0_0 = svld1_s8(ph16, y0->qs);
+                    const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16);
+                    const svint8_t qy1_0 = svld1_s8(ph16, y1->qs);
+                    const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16);
+
+                    sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx0_0, qy0_0),
+                                    svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16,
+                                    svdot_s32(svdup_n_s32(0), qx1_0, qy1_0),
+                                    svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1));
+            } break;
+        case 256:
+            {
+                //printf("sve256");
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load x
+                    const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
+                    const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
+
+                    // load y
+                    const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
+                    const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
+
+                    sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+                    sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(),
+                                svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
+            } break;
+        case 512:
+            {
+                // predicate for activating high 256 bit
+                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
+                // predicate for activating low 256 bit
+                const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32);
+
+                // predicate for activating high lanes for 8 float32 elements
+                const svbool_t ph8 = svptrue_pat_b32(SV_VL8);
+                // predicate for activating low lanes for 8 float32 elements
+                const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8);
+
+                svfloat32_t sumv00 = svdup_n_f32(0.0f);
+
+                for (; ib + 1 < nb; ib += 2) {
+                    const block_q8_0 * restrict x0 = &x[ib + 0];
+                    const block_q8_0 * restrict x1 = &x[ib + 1];
+                    const block_q8_0 * restrict y0 = &y[ib + 0];
+                    const block_q8_0 * restrict y1 = &y[ib + 1];
+
+                    // load 32 int8_t into one half of a vector and another 32 int8_t into the other
+                    // half of a second vector (inactive lanes are zero), then add them to form a
+                    // single 64-element vector
+                    // load x
+                    const svint8_t qx_32 = svld1_s8(ph32, x0->qs);
+                          svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2);
+
+                    qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64);
+
+                    // load y
+                    const svint8_t qy_32 = svld1_s8(ph32, y0->qs);
+                          svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2);
+
+                    qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64);
+
+                    // scale creation
+                    const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d);
+                    const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d);
+
+                    // duplicate deq1 in first half of vector and deq2 in second half of vector
+                    const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2);
+
+                    const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64));
+
+                    sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp);
+                }
+
+                sumf = svaddv_f32(svptrue_b32(), sumv00);
+                break;
+            }
+        default:
+            assert(false && "Unsupported vector length");
+            break;
+    }
+#elif defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q8_0 * restrict x0 = &x[ib + 0];
+        const block_q8_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib + 0];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        const int8x16_t x0_0 = vld1q_s8(x0->qs);
+        const int8x16_t x0_1 = vld1q_s8(x0->qs + 16);
+        const int8x16_t x1_0 = vld1q_s8(x1->qs);
+        const int8x16_t x1_1 = vld1q_s8(x1->qs + 16);
+
+        // load y
+        const int8x16_t y0_0 = vld1q_s8(y0->qs);
+        const int8x16_t y0_1 = vld1q_s8(y0->qs + 16);
+        const int8x16_t y1_0 = vld1q_s8(y1->qs);
+        const int8x16_t y1_1 = vld1q_s8(y1->qs + 16);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
+                        ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+        __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs);
+        __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        // Multiply q with scale and accumulate
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#elif defined(__AVX__)
+    __m256 accum = _mm256_setzero_ps();
+
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs);
+        const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1);
+        const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1);
+        const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
+        const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1);
+        const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1);
+        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+#elif defined(__riscv_v_intrinsic)
+    size_t vl = __riscv_vsetvl_e8m1(qk);
+
+    for (; ib < nb; ++ib) {
+        // load elements
+        vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl);
+        vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl);
+
+        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
+
+        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
+    }
+#elif defined(__POWER9_VECTOR__)
+    const vector signed int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char q8x0 = vec_xl( 0, x[ib].qs);
+        vector signed char q8x1 = vec_xl(16, x[ib].qs);
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed short qv0 = vec_mule(q8x0, q8y0);
+        vector signed short qv1 = vec_mulo(q8x0, q8y0);
+        vector signed short qv2 = vec_mule(q8x1, q8y1);
+        vector signed short qv3 = vec_mulo(q8x1, q8y1);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi1 = vec_sum4s(qv1, vsumi1);
+        vsumi0 = vec_sum4s(qv2, vsumi0);
+        vsumi1 = vec_sum4s(qv3, vsumi1);
+
+        vsumi0 = vec_add(vsumi0, vsumi1);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+    // Initialize accumulator with zeros
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    // Main loop
+    for (; ib < nb; ++ib) {
+        // Compute combined scale for the block
+        const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+        __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0);
+        __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0);
+
+        const __m256 q = mul_sum_i8_pairs_float(qx, qy);
+
+        // Multiply q with scale and accumulate
+        acc = __lasx_xvfmadd_s( d, q, acc );
+    }
+
+    sumf = hsum_float_8(acc);
+#endif
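+    // Scalar fallback for the remaining blocks: plain int8 dot product scaled by the two block deltas.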
+    for (; ib < nb; ++ib) {
+        int sumi = 0;
+
+        for (int j = 0; j < qk; j++) {
+            sumi += x[ib].qs[j]*y[ib].qs[j];
+        }
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d));
+    }
+
+    *s = sumf;
+}
+
+void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq1_0 * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
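+    // per-lane powers of 3 used when unpacking the base-3 (ternary) packed values of tq1_0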
+    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
+
+    const uint8x16_t shift = vld1q_u8(k_shift);
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
+        // first 32 bytes, each packing 5 ternary elements (5 * 32 = 160 values)
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
+            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
+            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
+            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
+            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs +   0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs +  16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs +  32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs +  48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs +  64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs +  80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs +  96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
+            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
+            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+            sumi0 = vdotq_s32(sumi0, sqx8, qy8);
+            sumi1 = vdotq_s32(sumi1, sqx9, qy9);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+#endif
+        }
+
+        // last 16 bytes of qs (5 elements per byte), along with the 4 qh bytes (4 elements each)
+        {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
+            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
+            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
+            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
+            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
+            qx5 = vmulq_u8(qx5, shift);
+
+            // multiply by 3 and keep the 2 bits above 8 bits
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
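+        // bsums are subtracted below because the ternary digits were dotted as unsigned {0, 1, 2}, i.e. with a +1 offset from the true {-1, 0, 1} values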
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+
+        // first 32 bytes of qs: each byte packs 5 ternary elements
+        {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs));
+            // 8-bit multiplies with shifts, masks and adds
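+            // the 16-bit shift leaks bits across byte boundaries, so the & with -8 (0xF8) clears what spilled in from the lower byte: each byte then holds 8*x, and adding x gives 9*x (mod 256)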
+            __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3
+            __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9
+            __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9
+            __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9
+
+            // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits?
+
+            // Cancel the +1 from avg so that it behaves like a halving add
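+            // (_mm256_avg_epu8 rounds up, computing (a + b + 1) >> 1; subtracting 1 first makes the chained avg below act as a truncating halving add, i.e. the same (3*q) >> 8 digit extraction as the NEON path)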
+            qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1));
+            qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1));
+            qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1));
+            qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1));
+            qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1));
+            // Multiply by 3 and get the top 2 bits
+            qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256()));
+            qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256()));
+            qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256()));
+            qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256()));
+            qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256()));
+            qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3));
+            qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs +   0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs +  96));
+            const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+            qx4 = _mm256_maddubs_epi16(qx4, qy4);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+            sumi2 = _mm256_add_epi16(sumi2, qx4);
+        }
+
+        // last 16 bytes of qs (5 elements per byte), along with the 4 qh bytes (4 elements each)
+        {
+            __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32));
+            uint32_t qh;
+            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
+            __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh));
+            __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3
+            __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9
+            __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9
+            __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9
+            __m256i qx01 = MM256_SET_M128I(qx1, qx0);
+            __m256i qx23 = MM256_SET_M128I(qx3, qx2);
+
+            // avx2 does not have 8-bit multiplies, so 16-bit it is.
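+            // widen the replicated qh bytes to 16 bits, scale each group of 4 lanes by a power of 3, mask back to 8 bits and repack - the same digit-selection idea as the NEON k_shift table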
+            qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1));
+            qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF));
+            __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1));
+
+            __m256i qx45 = MM256_SET_M128I(qx5, qx4);
+
+            // Cancel the +1 from avg so that it behaves like a halving add
+            qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1));
+            qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1));
+            qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1));
+            // Multiply by 3 and get the top 2 bits
+            qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256()));
+            qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256()));
+            qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256()));
+            qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3));
+            qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3));
+            qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3));
+
+            const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160));
+            const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192));
+            const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224));
+
+            qx01 = _mm256_maddubs_epi16(qx01, qy01);
+            qx23 = _mm256_maddubs_epi16(qx23, qy23);
+            qx45 = _mm256_maddubs_epi16(qx45, qy45);
+
+            sumi0 = _mm256_add_epi16(sumi0, qx01);
+            sumi1 = _mm256_add_epi16(sumi1, qx23);
+            sumi2 = _mm256_add_epi16(sumi2, qx45);
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2));
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int sum = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+            for (size_t l = 0; l < 5; ++l) {
+                for (size_t m = 0; m < 32; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[l];
+                    uint16_t xi = ((uint16_t) q * 3) >> 8;
+                    sum += (xi - 1) * y[i].qs[j*5 + l*32 + m];
+                }
+            }
+        }
+        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+            for (size_t l = 0; l < 5; ++l) {
+                for (size_t m = 0; m < 16; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[l];
+                    uint16_t xi = ((uint16_t) q * 3) >> 8;
+                    sum += (xi - 1) * y[i].qs[j*5 + l*16 + m];
+                }
+            }
+        }
+
+        for (size_t l = 0; l < 4; ++l) {
+            for (size_t j = 0; j < sizeof(x->qh); ++j) {
+                uint8_t q = x[i].qh[j] * pow3[l];
+                uint16_t xi = ((uint16_t) q * 3) >> 8;
+                sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j];
+            }
+        }
+
+        sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d);
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_tq2_0 * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+    float sumf = 0.0f;
+
+    const uint8x16_t m3 = vdupq_n_u8(3);
+
+    for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
+        int32x4_t sumi0 = vdupq_n_s32(0);
+        int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif
+
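+        // each qs byte packs four 2-bit values; shifting by 0/2/4/6 and masking with 3 unpacks them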
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
+            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
+            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
+            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
+            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
+            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
+            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
+            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
+
+            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
+            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
+            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
+            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
+            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
+            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
+            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
+            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
+
+            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 +   0);
+            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 +  16);
+            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 +  32);
+            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 +  48);
+            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 +  64);
+            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 +  80);
+            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 +  96);
+            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+            sumi0 = vdotq_s32(sumi0, sqx6, qy6);
+            sumi1 = vdotq_s32(sumi1, sqx7, qy7);
+#else
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
+            sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
+            sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
+        }
+
+        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
+        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
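+        // same bsums correction as in tq1_0: the unpacked 2-bit values carry a +1 offset from {-1, 0, 1}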
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
+        sumi0 = vaddq_s16(sumi0, sumi1);
+        sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));
+
+        sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
+    }
+
+    *s = sumf;
+
+#elif defined(__AVX2__)
+    __m256 sumf = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+        // 16-bit sums, because 256*127 still fits
+        __m256i sumi0 = _mm256_setzero_si256();
+        __m256i sumi1 = _mm256_setzero_si256();
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j));
+            __m256i qx1 = _mm256_srli_epi16(qx0, 2);
+            __m256i qx2 = _mm256_srli_epi16(qx0, 4);
+            __m256i qx3 = _mm256_srli_epi16(qx0, 6);
+
+            // 0, 1, 2 (should not be 3)
+            qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
+            qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
+            qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
+            qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
+
+            const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 +  0));
+            const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
+            const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
+            const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
+
+            qx0 = _mm256_maddubs_epi16(qx0, qy0);
+            qx1 = _mm256_maddubs_epi16(qx1, qy1);
+            qx2 = _mm256_maddubs_epi16(qx2, qy2);
+            qx3 = _mm256_maddubs_epi16(qx3, qy3);
+
+            sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
+            sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
+        }
+
+        const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
+
+        sumi0 = _mm256_add_epi16(sumi0, sumi1);
+        sumi0 = _mm256_sub_epi16(sumi0, ysum);
+        sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
+
+        sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
+    }
+
+    *s = hsum_float_8(sumf);
+
+#else
+    float sumf = 0.0f;
+
+    for (int i = 0; i < nb; ++i) {
+        int32_t sumi = 0;
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t k = 0; k < 32; ++k) {
+                    sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1);
+                }
+            }
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        sumf += (float) sumi * d;
+    }
+
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q2_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
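+    // each scales[g] byte packs a 4-bit scale (low nibble) and a 4-bit min (high nibble) for a group of 16 quants;
+    // every path below computes d * sum_g(scale_g * (q2 . q8)) - dmin * sum_g(min_g * bsum_g)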
+#ifdef __ARM_NEON
+    const uint8x16_t m3 = vdupq_n_u8(0x3);
+    const uint8x16_t m4 = vdupq_n_u8(0xF);
+
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    ggml_int8x16x2_t q2bytes;
+    uint8_t aux[16];
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        const uint8_t * restrict sc = x[i].scales;
+
+        const uint8x16_t mins_and_scales = vld1q_u8(sc);
+        const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
+        vst1q_u8(aux, scales);
+
+        const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}};
+        const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
+                                       vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
+        const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
+                                       vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
+        sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
+
+        int isum = 0;
+        int is = 0;
+
+// We use this macro instead of a function call because, for some reason,
+// the function-call version runs 2-3% slower, even when the function is declared inline
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+        q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+        MULTIPLY_ACCUM_WITH_SCALE((index));
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
+
+            ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
+            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
+
+            MULTIPLY_ACCUM_WITH_SCALE(0);
+
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
+
+            is += 8;
+        }
+
+        sum += d * isum;
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m256i mins = _mm256_cvtepi8_epi16(mins8);
+        const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
+
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
+            const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
+            const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
+            const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
+
+            __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
+            __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
+            __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
+            __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
+
+            p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = _mm256_add_epi32(p0, p1);
+            p2 = _mm256_add_epi32(p2, p3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(0x3);
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // load mins and scales from block_q2_K.scales[QK_K/16]
+        const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales16 = _mm_and_si128(mins_and_scales, m4);
+        const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
+        const __m128i mins_0 = _mm_cvtepi8_epi16(mins16);
+        const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16));
+
+        // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2
+        const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0]));
+        const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8]));
+
+        // sumf += -dmin * summs in 32bits*8
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc);
+
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales16);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // load 2bits*16*8 from block_q2_K.qs[QK_K/4]
+            __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_0 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+            q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16;
+            const __m128i q2_1 = _mm_and_si128(q2bits, m3);
+            const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3);
+            const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3);
+            const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3);
+
+            // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8
+            __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0);
+            __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1);
+            __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2);
+            __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3);
+            __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4);
+            __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5);
+            __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6);
+            __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7);
+
+            // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7);
+
+            p0 = _mm_add_epi32(p0, p1);
+            p2 = _mm_add_epi32(p2, p3);
+            p4 = _mm_add_epi32(p4, p5);
+            p6 = _mm_add_epi32(p6, p7);
+
+            // isum in 32bits*4*2
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6));
+        }
+
+        // sumf += dall * isum - dmin * summs in 32bits
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                            1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
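+    // temp_01 provides 0/1 gather indices: each vrgather below broadcasts one scale byte over the first 16 lanes and the next scale byte over the last 16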
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        size_t vl = 16;
+
+        vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
+        vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
+
+        vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
+
+        vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
+        vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
+        vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
+        vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
+        vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+
+        sumf  += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
+
+        vl = 32;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
+
+        uint8_t is=0;
+        int isum=0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load Q2
+            vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
+
+            vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
+            vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl);
+
+            // duplicate scale elements for product
+            vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl);
+            vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl);
+            vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl);
+            vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl);
+
+            vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
+            vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
+            vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
+            vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
+
+            // load Q8
+            vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl);
+            vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl);
+
+            vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
+            vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
+            vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
+            vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
+
+            isum += __riscv_vmv_x_s_i32m1_i32(isum1);
+
+            q2+=32;  q8+=128;  is=8;
+
+        }
+
+        sumf += dall * isum;
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0x3);
+    const vector signed char lowScaleMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales);
+        vector signed char vscales = vec_and(q2xmins, lowScaleMask);
+
+        q2xmins = vec_sr(q2xmins, v4);
+        vector signed short q2xmins0 = vec_unpackh(q2xmins);
+        vector signed short q2xmins1 = vec_unpackl(q2xmins);
+
+        vector signed int prod0 = vec_mule(q2xmins0, q8ysums0);
+        vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0);
+        vector signed int prod2 = vec_mule(q2xmins1, q8ysums1);
+        vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q2);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q2);
+            q2 += 32;
+
+            vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+            vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask);
+            vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask);
+            vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask);
+            vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+            vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask);
+            vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask);
+            vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y02 = vec_xl( 64, q8);
+            vector signed char q8y12 = vec_xl( 80, q8);
+            vector signed char q8y03 = vec_xl( 96, q8);
+            vector signed char q8y13 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed int qv0 = vec_msum(q8y00, q2x00, v0);
+            vector signed int qv1 = vec_msum(q8y01, q2x01, v0);
+            vector signed int qv2 = vec_msum(q8y02, q2x02, v0);
+            vector signed int qv3 = vec_msum(q8y03, q2x03, v0);
+            vector signed int qv4 = vec_msum(q8y10, q2x10, v0);
+            vector signed int qv5 = vec_msum(q8y11, q2x11, v0);
+            vector signed int qv6 = vec_msum(q8y12, q2x12, v0);
+            vector signed int qv7 = vec_msum(q8y13, q2x13, v0);
+
+            vector signed short vscales_07 = vec_unpackh(vscales);
+            vector signed int vscales_03 = vec_unpackh(vscales_07);
+            vector signed int vscales_47 = vec_unpackl(vscales_07);
+            vector signed int vs0 = vec_splat(vscales_03, 0);
+            vector signed int vs1 = vec_splat(vscales_03, 1);
+            vector signed int vs2 = vec_splat(vscales_03, 2);
+            vector signed int vs3 = vec_splat(vscales_03, 3);
+            vector signed int vs4 = vec_splat(vscales_47, 0);
+            vector signed int vs5 = vec_splat(vscales_47, 1);
+            vector signed int vs6 = vec_splat(vscales_47, 2);
+            vector signed int vs7 = vec_splat(vscales_47, 3);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3);
+            vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4);
+            vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5);
+            vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6);
+            vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
+    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q2 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
+        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
+        const __m256i mins = lasx_ext8_16(mins8);
+        const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
+
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
+
+        const __m256i all_scales = lasx_ext8_16(scales8);
+        const __m128i l_scales = lasx_extracti128(all_scales, 0);
+        const __m128i h_scales = lasx_extracti128(all_scales, 1);
+        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32;
+
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
+            const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
+            const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
+            const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+
+            __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
+            __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
+            __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
+            __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+
+            p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
+            p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
+            p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
+            p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+
+            p0 = __lasx_xvadd_w(p0, p1);
+            p2 = __lasx_xvadd_w(p2, p3);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2));
+        }
+
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * q2 = x[i].qs;
+        const  int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        int summs = 0;
+        for (int j = 0; j < 16; ++j) {
+            summs += y[i].bsums[j] * (sc[j] >> 4);
+        }
+
+        const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        int isum = 0;
+        int is = 0;
+        int d;
+        for (int k = 0; k < QK_K/128; ++k) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+                d = sc[is++] & 0xF;
+                int isuml = 0;
+                for (int l =  0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                d = sc[is++] & 0xF;
+                isuml = 0;
+                for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
+                isum += d * isuml;
+                shift += 2;
+                q8 += 32;
+            }
+            q2 += 32;
+        }
+        sumf += dall * isum - dmin * summs;
+    }
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    const block_q3_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
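+    // q3_K quants are 3 bits each: the low 2 bits come from qs and the high bit from hmask, decoding to (low2 | high << 2) - 4;
+    // the 6-bit group scales are unpacked from the 12 scale bytes with kmask1/kmask2 and offset by -32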
+#ifdef __ARM_NEON
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const uint8x16_t m3b = vdupq_n_u8(0x3);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+
+    const uint8x16_t m0 = vdupq_n_u8(1);
+    const uint8x16_t m1 = vshlq_n_u8(m0, 1);
+    const uint8x16_t m2 = vshlq_n_u8(m0, 2);
+    const uint8x16_t m3 = vshlq_n_u8(m0, 3);
+    const int8_t m32 = 32;
+
+    ggml_int8x16x4_t q3bytes;
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+        ggml_uint8x16x4_t q3h;
+
+        int32_t isum = 0;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
+            const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
+            const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
+            q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
+            q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
+            q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
+
+            scale += 4;
+
+            q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
+            q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
+            q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
+            q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
+
+            q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
+            q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
+            q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
+            q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
+
+            scale += 4;
+
+            if (j == 0) {
+                qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
+                qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
+            }
+
+        }
+        sum += d * isum;
+
+    }
+
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m3 = _mm256_set1_epi8(3);
+    const __m256i mone = _mm256_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
+        const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
+        const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
+        const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
+
+        // integer accumulator
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+        int is  = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
+
+            // prepare low and high bits
+            const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
+            const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
+            const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
+            const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
+            const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16,
+            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
+            // and 2 if the high bit was set)
+            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
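+            // _mm256_maddubs_epi16 sums adjacent byte products, so each 16-bit lane of p16_* now holds the
+            // contribution of two quants: (q3l - q3h) multiplied with the corresponding q8 values.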
+
+            // multiply with scales
+            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = _mm256_add_epi32(p16_0, p16_1);
+            p16_2 = _mm256_add_epi32(p16_2, p16_3);
+            sumi  = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
+
+        }
+
+        // multiply with block scale and accumulate
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i mone = _mm_set1_epi8(1);
+    const __m128i m32 = _mm_set1_epi8(32);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    const uint32_t *aux;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Set up scales
+        aux = (const uint32_t *)x[i].scales;
+        __m128i scales128 = _mm_set_epi32(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = _mm_sub_epi8(scales128, m32);
+        const __m128i scales_0 = _mm_cvtepi8_epi16(scales128);
+        const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128));
+        const __m128i scales[2] = { scales_0, scales_1 };
+
+        // high bit *128*2 from block_q3_K.hmask[QK_K/8]
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]);
+
+        // integer accumulator
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4]
+            const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+            const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16;
+
+            // prepare low and high bits
+            const int bit = j << 2;
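+            // each 128-quant chunk consumes four of the high-bit planes, so the base plane index is 4*j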
+
+            const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3);
+            const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3);
+            const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2);
+            const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2);
+
+            const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3);
+            const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3);
+            const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+            const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2);
+
+            const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3);
+            const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3);
+            const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+            const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2);
+
+            const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3);
+            const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3);
+            const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+            const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2);
+
+            // load Q8 quants from block_q8_K.qs[QK_K]
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            // Dot product: the low 2-bit part and the high-bit part are multiplied with q8 separately so that
+            // _mm_maddubs_epi16 can be used, then the results are subtracted. The high-bit part already holds
+            // the offset to subtract: it is 4 where the high bit was not set and 0 where it was set, so the
+            // difference corresponds to the signed quant value in [-4, 3].
+            __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0);
+            __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1);
+            __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2);
+            __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3);
+            __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4);
+            __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5);
+            __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6);
+            __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7);
+
+            __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7);
+
+            p16_0 = _mm_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm_sub_epi16(p16_3, q8s_3);
+            p16_4 = _mm_sub_epi16(p16_4, q8s_4);
+            p16_5 = _mm_sub_epi16(p16_5, q8s_5);
+            p16_6 = _mm_sub_epi16(p16_6, q8s_6);
+            p16_7 = _mm_sub_epi16(p16_7, q8s_7);
+
+            // multiply with scales
+            __m128i shuffle = _mm_set1_epi16(0x0100);
+            p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7);
+
+            // accumulate
+            p16_0 = _mm_add_epi32(p16_0, p16_1);
+            p16_2 = _mm_add_epi32(p16_2, p16_3);
+            p16_4 = _mm_add_epi32(p16_4, p16_5);
+            p16_6 = _mm_add_epi32(p16_6, p16_7);
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6));
+
+        }
+
+        // multiply with block scale and accumulate
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict qh = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+        for (int j = 0; j < 16; ++j) scale[j] -= 32;
+
+        size_t vl = 32;
+        uint8_t m =  1;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl);
+
+        int sum_t = 0;
+
+        for (int j = 0; j < QK_K; j += 128) {
+
+            vl = 32;
+
+            // load Q3
+            vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl);
+
+            vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl));
+            vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl));
+            vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl));
+            vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl));
+
+            // compute mask for subtraction
+            vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl);
+            vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl);
+            m <<= 1;
+
+            vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl);
+            vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl);
+            m <<= 1;
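+            // q3_m0..q3_m3 now hold the signed quants: 4 was subtracted, under the mask, wherever the
+            // corresponding high bit in qh is zero.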
+
+            // load Q8 and take product with Q3
+            vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            // retrieve lane to multiply with scale
+            vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
+            vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
+            vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
+            vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl);
+            vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl);
+            vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl);
+            vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl);
+            vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl);
+
+            sum_t +=  __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q3 += 32;    q8 += 128;   scale += 8;
+
+        }
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        sumf += d*sum_t;
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0x3);
+    const vector signed char lowMask1 = vec_splats((int8_t)0xf);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector signed char v1 = vec_splats((signed char)0x1);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector signed char off = vec_splats((signed char)0x20);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(u0, lowMask1);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2));
+        vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4);
+        vector signed char u31 = vec_and(u3, lowMask2);
+
+        u1 = vec_or(u1, u30);
+        u2 = vec_or(vec_sr(u0, v4), u31);
+
+        vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2);
+        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask);
+        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask);
+
+        vscales = vec_sub(vscales, off);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q3);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q3);
+            q3 += 32;
+
+            // the low 2 bits
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask);
+            vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask);
+            vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask);
+            vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask);
+            vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask);
+
+            // the 3rd bit
+            vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2);
+            vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2);
+            vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2);
+            vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2);
+            vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2);
+            vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2);
+            vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2);
+            vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2);
+            qxhs0 = vec_sr(qxhs0, v4);
+            qxhs1 = vec_sr(qxhs1, v4);
+
+            vector signed char q3x00 = vec_sub(qxs00, qxh00);
+            vector signed char q3x01 = vec_sub(qxs01, qxh01);
+            vector signed char q3x02 = vec_sub(qxs02, qxh02);
+            vector signed char q3x03 = vec_sub(qxs03, qxh03);
+            vector signed char q3x10 = vec_sub(qxs10, qxh10);
+            vector signed char q3x11 = vec_sub(qxs11, qxh11);
+            vector signed char q3x12 = vec_sub(qxs12, qxh12);
+            vector signed char q3x13 = vec_sub(qxs13, qxh13);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y02 = vec_xl( 64, q8);
+            vector signed char q8y12 = vec_xl( 80, q8);
+            vector signed char q8y03 = vec_xl( 96, q8);
+            vector signed char q8y13 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed short vscales_h = vec_unpackh(vscales);
+            vector signed short vs0 = vec_splat(vscales_h, 0);
+            vector signed short vs1 = vec_splat(vscales_h, 1);
+            vector signed short vs2 = vec_splat(vscales_h, 2);
+            vector signed short vs3 = vec_splat(vscales_h, 3);
+            vector signed short vs4 = vec_splat(vscales_h, 4);
+            vector signed short vs5 = vec_splat(vscales_h, 5);
+            vector signed short vs6 = vec_splat(vscales_h, 6);
+            vector signed short vs7 = vec_splat(vscales_h, 7);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00));
+            vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01));
+            vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02));
+            vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03));
+            vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10));
+            vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11));
+            vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12));
+            vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13));
+
+            vsumi0 = vec_msum(qv00, vs0, vsumi0);
+            vsumi1 = vec_msum(qv01, vs2, vsumi1);
+            vsumi2 = vec_msum(qv02, vs4, vsumi2);
+            vsumi3 = vec_msum(qv03, vs6, vsumi3);
+            vsumi4 = vec_msum(qv10, vs1, vsumi4);
+            vsumi5 = vec_msum(qv11, vs3, vsumi5);
+            vsumi6 = vec_msum(qv12, vs5, vsumi6);
+            vsumi7 = vec_msum(qv13, vs7, vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
+    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
+    const __m256i mone = __lasx_xvreplgr2vr_b(1);
+    const __m128i m32 = __lsx_vreplgr2vr_b(32);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    uint32_t aux[3];
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * restrict q3 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        __m128i scales128 = lsx_set_w(
+                ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
+                ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
+                (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
+                (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
+        scales128 = __lsx_vsub_b(scales128, m32);
+        const __m256i all_scales = lasx_ext8_16(scales128);
+        const __m128i l_scales = lasx_extracti128(all_scales, 0);
+        const __m128i h_scales = lasx_extracti128(all_scales, 1);
+        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        // high bit
+        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
+
+        // integer accumulator
+        __m256i sumi = __lasx_xvldi(0);
+
+        int bit = 0;
+        int is  = 0;
+        __m256i xvbit;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            // load low 2 bits
+            const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            // prepare low and high bits
+            const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
+            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
+            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
+            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit);
+            const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
+            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
+            ++bit;
+
+            // load Q8 quants
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            // Dot product: the low 2-bit part and the high-bit part are multiplied with q8 separately so that
+            // lasx_maddubs_h can be used, then the results are subtracted. The high-bit part already holds
+            // the offset to subtract: it is 4 where the high bit was not set and 0 where it was set, so the
+            // difference corresponds to the signed quant value in [-4, 3].
+            __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
+            __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
+            __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
+            __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
+
+            __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
+            __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
+            __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
+            __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
+
+            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+
+            // multiply with scales
+            p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
+            p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
+            p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
+            p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+
+            // accumulate
+            p16_0 = __lasx_xvadd_w(p16_0, p16_1);
+            p16_2 = __lasx_xvadd_w(p16_2, p16_3);
+            sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
+        }
+        // multiply with block scale and accumulate
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+    // scalar version
+    // This function is written like this so the compiler can manage to vectorize most of it
+    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
+    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
+    // The ideal situation would be if we could just write the code once, and the compiler would
+    // automatically produce the best possible set of machine instructions, instead of us having to manually
+    // write vectorized versions for AVX, ARM_NEON, etc.
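+    //
+    // Outline: expand the 2-bit quants into aux8 (subtracting 4 where the high bit is not set), unpack the
+    // 16 6-bit scales (offset by 32), accumulate scale * (q8 . quant) per 16-value group into aux32, then
+    // apply the fp32 block scale d.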
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    uint32_t auxs[4];
+    const int8_t * scales = (const int8_t*)auxs;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
+            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
+            a += 32; m <<= 1;
+            q3 += 32;
+        }
+        a = aux8;
+
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+
+#endif
+
+}
+
+void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q4_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
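+    // kmask1/kmask2/kmask3 and utmp are used below to unpack the 12-byte scales field into eight 6-bit
+    // block scales followed by eight 6-bit block mins.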
+
+#ifdef __ARM_NEON
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x2_t q4bytes;
+    ggml_int8x16x2_t q8bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
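+        // The block sums of q8 (bsums) times the per-sub-block mins give the correction for the offset
+        // quantization; it is scaled by dmin and subtracted up front.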
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[0], m4b));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8  (q4bits.val[1], m4b));
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+            sumi1 += vaddvq_s32(p1) * scales[2*j+0];
+
+            q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
+            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
+            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
+
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
+
+            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
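+        // utmp[0..1] now hold the eight 6-bit scales and utmp[2..3] the eight 6-bit mins, ready to be
+        // widened to 16-bit lanes in mins_and_scales below.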
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4l = _mm256_and_si256(q4bits, m4);
+            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);
+
+            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
+            p16l = _mm256_madd_epi16(scale_l, p16l);
+
+            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
+            p16h = _mm256_madd_epi16(scale_h, p16h);
+            const __m256i sumj = _mm256_add_epi32(p16l, p16h);
+
+            sumi = _mm256_add_epi32(sumi, sumj);
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i m2 = _mm_set1_epi8(0x2);
+
+    __m256 acc = _mm256_setzero_ps();
+    __m128 acc_m = _mm_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_0 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+            q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4l_1 = _mm_and_si128(q4bits, m4);
+            const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4);
+
+            const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_0 = _mm_add_epi32(sumi_0, p16l);
+            const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16l = _mm_maddubs_epi16(q4l_1, q8l_1);
+            p16l = _mm_madd_epi16(scale_l, p16l);
+            sumi_1 = _mm_add_epi32(sumi_1, p16l);
+
+            const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_0 = _mm_add_epi32(sumi_0, p16h);
+            const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            p16h = _mm_maddubs_epi16(q4h_1, q8h_1);
+            p16h = _mm_madd_epi16(scale_h, p16h);
+            sumi_1 = _mm_add_epi32(sumi_1, p16h);
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
+    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));
+
+    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        size_t vl = 8;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums   = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8  = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t  prod   = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        vl = 32;
+
+        int32_t sum_1 = 0;
+        int32_t sum_2 = 0;
+
+        vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q4
+            vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl);
+
+            // load Q8 and multiply it with lower Q4 nibble
+            vint8m1_t  q8_0 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl));
+            vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl);
+            vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl);
+
+            sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0];
+
+            // load Q8 and multiply it with upper Q4 nibble
+            vint8m1_t  q8_1 = __riscv_vle8_v_i8m1(q8+32, vl);
+            vint8m1_t  q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl));
+            vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl);
+            vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl);
+
+            sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1];
+
+            q4 += 32;    q8 += 64;
+
+        }
+
+        sumf += d*(sum_1 + sum_2);
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((uint8_t)2);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+        UNUSED(kmask3);
+        UNUSED(utmp);
+
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = vec_sr(u2, v4);
+
+        vector signed char u30 = u1;
+        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+        u1 = vec_and(u0, lowMask1);
+        u2 = vec_or(u30, u31);
+
+        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+        vector signed short vscales = vec_unpackh(utmps);
+        vector signed short q4xmins = vec_unpackl(utmps);
+        vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins);
+        vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins);
+
+        vector signed int prod0 = vec_mule(q4xmins0, q8ysums0);
+        vector signed int prod1 = vec_mule(q4xmins1, q8ysums1);
+        vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0);
+        vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; j+=2) {
+            __builtin_prefetch(q4, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+            vector signed char qxs2 = (vector signed char)vec_xl(32, q4);
+            vector signed char qxs3 = (vector signed char)vec_xl(48, q4);
+            q4 += 64;
+
+            vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask);
+            vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4);
+            vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask);
+            vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4);
+            vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask);
+            vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4);
+            vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask);
+            vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y01 = vec_xl( 32, q8);
+            vector signed char q8y11 = vec_xl( 48, q8);
+            vector signed char q8y20 = vec_xl( 64, q8);
+            vector signed char q8y30 = vec_xl( 80, q8);
+            vector signed char q8y21 = vec_xl( 96, q8);
+            vector signed char q8y31 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed int qv00 = vec_msum(q8y00, q4x00, v0);
+            vector signed int qv01 = vec_msum(q8y01, q4x01, v0);
+            vector signed int qv10 = vec_msum(q8y10, q4x10, v0);
+            vector signed int qv11 = vec_msum(q8y11, q4x11, v0);
+            vector signed int qv20 = vec_msum(q8y20, q4x20, v0);
+            vector signed int qv21 = vec_msum(q8y21, q4x21, v0);
+            vector signed int qv30 = vec_msum(q8y30, q4x30, v0);
+            vector signed int qv31 = vec_msum(q8y31, q4x31, v0);
+
+            vector signed int vscales_h = vec_unpackh(vscales);
+            vector signed int vs0 = vec_splat(vscales_h, 0);
+            vector signed int vs1 = vec_splat(vscales_h, 1);
+            vector signed int vs2 = vec_splat(vscales_h, 2);
+            vector signed int vs3 = vec_splat(vscales_h, 3);
+            vscales = vec_sld(vscales, vscales, 8);
+
+            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3);
+
+            vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
+
+    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
+
+        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
+        const __m256i scales = lasx_insertf128(sc128, sc128);
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
+            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+
+            const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            p16l = lasx_madd_h(scale_l, p16l);
+
+            const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            p16h = lasx_madd_h(scale_h, p16h);
+            const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
+
+            sumi = __lasx_xvadd_w(sumi, sumj);
+        }
+
+        __m256 vd = __lasx_xvreplfr2vr_s(d);
+        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
+    }
+
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee));
+    __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0);
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
+
+    ft_union fi;
+    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
+    *s = hsum_float_8(acc) + fi.f;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            a += 32;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            a += 32; q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy,  size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q5_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    static const uint32_t kmask1 = 0x3f3f3f3f;
+    static const uint32_t kmask2 = 0x0f0f0f0f;
+    static const uint32_t kmask3 = 0x03030303;
+
+    uint32_t utmp[4];
+
+#ifdef __ARM_NEON
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    const uint8x16_t mone = vdupq_n_u8(1);
+    const uint8x16_t mtwo = vdupq_n_u8(2);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x4_t q5bytes;
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        int32_t sumi_mins = vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
+
+        ggml_uint8x16x4_t q5h;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
+            const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
+            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
+            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
+            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);
+
+            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
+            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
+            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
+            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
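+            // q5bytes now holds the full 5-bit quants (0..31): the low nibble from q5bits OR'd with the
+            // matching high bit from qh shifted into bit position 4.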
+
+            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
+            sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;
+        }
+
+        sumf += d * sumi - dmin * sumi_mins;
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m256i mone  = _mm256_set1_epi8(1);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
+        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
+        const __m256i scales = MM256_SET_M128I(sc128, sc128);
+
+        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
+        __m256i hmask = mone;
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int bit = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;
+
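+            // rebuild each 5-bit value: low nibble from q5bits plus the matching hbits
+            // bit shifted up to bit 4 (hmask walks one bit position per half-iteration)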
+            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
+            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
+            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
+            const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
+            hmask = _mm256_slli_epi16(hmask, 1);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);
+
+            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm256_madd_epi16(scale_1, p16_1);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __AVX__
+
+    const __m128i m4 = _mm_set1_epi8(0xF);
+    const __m128i mzero = _mm_setzero_si128();
+    const __m128i mone  = _mm_set1_epi8(1);
+    const __m128i m2 = _mm_set1_epi8(2);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i scales = _mm_cvtepu8_epi16(utmps);
+        const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps));
+
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]);
+        const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1);
+        const __m128i prod = _mm_madd_epi16(mins, q8s);
+        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
+        summs += dmin * _mm_extract_epi32(hsum, 0);
+
+        const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]);
+        const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]);
+        __m128i hmask = mone;
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        int bit = 0;
+
+        __m128i shuffle = _mm_set1_epi16(0x0100);
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle);
+            shuffle = _mm_add_epi16(shuffle, m2);
+
+            const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+            const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16;
+
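+            // same 4-bit + high-bit reconstruction as the AVX2 path, done in 128-bit halves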
+            __m128i q5l_0 = _mm_and_si128(q5bits_0, m4);
+            __m128i q5l_1 = _mm_and_si128(q5bits_1, m4);
+            __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            __m128i q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            __m128i q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_0 = _mm_madd_epi16(scale_0, p16_0);
+            p16_1 = _mm_madd_epi16(scale_0, p16_1);
+
+            q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4);
+            q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4);
+            q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4);
+            q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4);
+            q5_0  = _mm_add_epi8(q5l_0, q5h_0);
+            q5_1  = _mm_add_epi8(q5l_1, q5h_1);
+            hmask = _mm_slli_epi16(hmask, 1);
+
+            q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0);
+            __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1);
+            p16_2 = _mm_madd_epi16(scale_1, p16_2);
+            p16_3 = _mm_madd_epi16(scale_1, p16_3);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+
+        }
+
+        __m256 vd = _mm256_set1_ps(d);
+        __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#elif defined __riscv_v_intrinsic
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    float sumf = 0;
+    float sums = 0.0;
+
+    size_t vl;
+
+    for (int i = 0; i < nb; ++i) {
+
+        vl = 8;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+
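+        // strided loads gather the even/odd entries of bsums; adding them pairs the
+        // 16 sub-block sums into 8, matching the 8 mins used for the dmin correction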
+        vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl);
+        vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl);
+        vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl);
+        vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl));
+        vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl);
+
+        vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
+        sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi);
+
+        vl = 32;
+        int32_t aux32 = 0;
+        int is = 0;
+
+        uint8_t m = 1;
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+        vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl);
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // load Q5 and Q8
+            vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl);
+            vint8m1_t  q8_y1 = __riscv_vle8_v_i8m1(q8, vl);
+            vint8m1_t  q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl);
+
+            // compute mask for addition
+            vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl));
+            vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl);
+            vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl);
+            m <<= 1;
+
+            vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl));
+            vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl);
+            vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl);
+            vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl);
+            m <<= 1;
+
+            vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl);
+            vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl);
+
+            vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl);
+            vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl);
+
+            vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl);
+            vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl);
+
+            aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2);
+            q5 += 32;    q8 += 64;
+
+        }
+
+        vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1);
+        sums += __riscv_vfmv_f_s_f32m1_f32(vaux);
+
+    }
+
+    *s = sumf+sums;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed char lowMask1 = vec_splats((int8_t)0x3f);
+    const vector signed char lowMask2 = vec_splats((int8_t)0x30);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v1 = vec_splats((unsigned char)0x1);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin));
+        vector float vdmin = vec_mul(vxmin, vyd);
+
+        UNUSED(kmask1);
+        UNUSED(kmask2);
+        UNUSED(kmask3);
+        UNUSED(utmp);
+
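+        // unpack the 12 bytes of packed 6-bit scales/mins directly with vector ops
+        // (the utmp/kmask shuffle used by the other paths is not needed here)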
+        vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8);
+        vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2);
+        vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4);
+        vector signed char u3 = vec_sr(u2, v4);
+
+        vector signed char u30 = u1;
+        vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3);
+
+        u1 = vec_and(u0, lowMask1);
+        u2 = vec_or(u30, u31);
+
+        vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2);
+
+        vector signed short q8ysums0 = vec_xl( 0, y[i].bsums);
+        vector signed short q8ysums1 = vec_xl(16, y[i].bsums);
+
+        vector signed short vscales = vec_unpackh(utmps);
+
+        vector signed short q5xmins = vec_unpackl(utmps);
+        vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins);
+        vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins);
+
+        vector signed int prod0 = vec_mule(q5xmins0, q8ysums0);
+        vector signed int prod1 = vec_mule(q5xmins1, q8ysums1);
+        vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0);
+        vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1);
+
+        vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0);
+        vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1);
+        vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2);
+        vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3);
+
+        vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh);
+        vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            __builtin_prefetch(q5, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q5);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q5);
+            q5 += 32;
+
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_sr(qxs0, v4);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_sr(qxs1, v4);
+
+            vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4);
+            vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3);
+            vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4);
+            vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3);
+            qxhs0 = vec_sr(qxhs0, v2);
+            qxhs1 = vec_sr(qxhs1, v2);
+
+            vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00);
+            vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01);
+            vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10);
+            vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11);
+
+            vector signed char q8y00 = vec_xl( 0, q8);
+            vector signed char q8y10 = vec_xl(16, q8);
+            vector signed char q8y01 = vec_xl(32, q8);
+            vector signed char q8y11 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed int qv00 = vec_msum(q8y00, q5x00, v0);
+            vector signed int qv01 = vec_msum(q8y01, q5x01, v0);
+            vector signed int qv10 = vec_msum(q8y10, q5x10, v0);
+            vector signed int qv11 = vec_msum(q8y11, q5x11, v0);
+
+            vector signed int vscales_h = vec_unpackh(vscales);
+            vector signed int vs0 = vec_splat(vscales_h, 0);
+            vector signed int vs1 = vec_splat(vscales_h, 1);
+            vscales = vec_sld(vscales, vscales, 12);
+
+            vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0);
+            vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1);
+            vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2);
+            vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+    GGML_UNUSED(kmask1);
+    GGML_UNUSED(kmask2);
+    GGML_UNUSED(kmask3);
+
+    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+    const __m128i mzero = __lsx_vldi(0);
+    const __m256i mone  = __lasx_xvreplgr2vr_b(1);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    float summs = 0.f;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+
+        const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
+        const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
+        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
+        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+
+        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
+        const __m256i scales = lasx_insertf128(sc128, sc128);
+
+        const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
+        __m256i hmask = mone;
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        int bit = 0;
+        __m256i xvbit;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+
+            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
+            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+
+            const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
+
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
+            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
+            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
+            const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
+            hmask = __lasx_xvslli_h(hmask, 1);
+
+            xvbit = __lasx_xvreplgr2vr_h(bit++);
+            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
+            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
+            const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
+            hmask = __lasx_xvslli_h(hmask, 1);
+
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
+            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+
+            p16_0 = lasx_madd_h(scale_0, p16_0);
+            p16_1 = lasx_madd_h(scale_1, p16_1);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+
+        }
+
+        __m256 vd = __lasx_xvreplfr2vr_s(d);
+        acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc);
+
+    }
+
+    *s = hsum_float_8(acc) + summs;
+
+#else
+
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    const uint8_t * mins   = (const uint8_t*)&utmp[2];
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].qs;
+        const uint8_t * restrict hm = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
+        uint8_t m = 1;
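+        // expand the 256 packed 5-bit values into aux8: low/high nibbles of qs plus
+        // one bit from qh (selected by mask m) contributing +16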
+        for (int j = 0; j < QK_K/64; ++j) {
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l]  >> 4);
+            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
+            a += 32; m <<= 1;
+            q4 += 32;
+        }
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        int sumi = 0;
+        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            int32_t scale = scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf -= dmin * sumi;
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_q6_K * restrict x = vx;
+    const block_q8_K * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#ifdef __ARM_NEON
+    float sum = 0;
+
+    const uint8x16_t m4b = vdupq_n_u8(0xF);
+    const int32x4_t  vzero = vdupq_n_s32(0);
+    //const int8x16_t  m32s = vdupq_n_s8(32);
+
+    const uint8x16_t mone = vdupq_n_u8(3);
+
+    ggml_int8x16x4_t q6bytes;
+    ggml_uint8x16x4_t q6h;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
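+        // dot bsums with the 16 sub-block scales; q6_K values are stored with a +32
+        // bias, so 32 * isum_mins is subtracted from the accumulated sum below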
+        const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
+        const int8x16_t scales = vld1q_s8(scale);
+        const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}};
+
+        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
+                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
+                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
+        int32_t isum_mins = vaddvq_s32(prod);
+
+        int32_t isum = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
+            ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
+            ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
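+            // pull out the two high bits of each 6-bit value from qh and shift them to bits 4..5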
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
+            uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 2);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
+            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
+            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
+            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+
+            scale += 4;
+
+            q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            shifted = vshrq_n_u8(qhbits.val[0], 4);
+            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 4);
+            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[0], 6);
+            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+            shifted = vshrq_n_u8(qhbits.val[1], 6);
+            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
+
+            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
+            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
+            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
+            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
+            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
+            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
+            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
+            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));
+
+            isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
+                    vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
+            scale += 4;
+        }
+        //sum += isum * d_all * y[i].d;
+        sum += d_all * y[i].d * (isum - 32 * isum_mins);
+
+    }
+    *s = sum;
+
+#elif defined __AVX2__
+
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(3);
+    const __m256i m32s = _mm256_set1_epi8(32);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+
+        __m256i sumi = _mm256_setzero_si256();
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
+            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;
+
+            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
+            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
+            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
+            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);
+
+            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
+            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
+            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
+            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);
+
+            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
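+            // the 6-bit values still carry the +32 bias here; computing 32*q8 and
+            // subtracting it from the products applies the bias correction in-loop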
+            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
+            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
+            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
+            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);
+
+            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
+            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
+            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
+            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);
+
+            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
+            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
+            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
+            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);
+
+            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
+            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
+            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));
+
+        }
+
+        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __AVX__
+
+    const __m128i m3 = _mm_set1_epi8(3);
+    const __m128i m15 = _mm_set1_epi8(15);
+
+    __m256 acc = _mm256_setzero_ps();
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // handle the q6_k -32 offset separately using bsums
+        const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums);
+        const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1);
+        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);
+        const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales);
+        const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8));
+        const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5);
+        const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5);
+
+        __m128i sumi_0 = _mm_setzero_si128();
+        __m128i sumi_1 = _mm_setzero_si128();
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+            const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16;
+
+            const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4);
+            const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4);
+            const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2);
+            const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48));
+            const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48));
+            const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2);
+            const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2);
+
+            const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+            const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16;
+
+            const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0);
+            const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1);
+            const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2);
+            const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3);
+            const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4);
+            const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5);
+            const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6);
+            const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7);
+
+            const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+            const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16;
+
+            __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0);
+            __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1);
+            __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2);
+            __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3);
+            __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4);
+            __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5);
+            __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6);
+            __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7);
+
+            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0);
+            p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1);
+            p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2);
+            p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3);
+            p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4);
+            p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5);
+            p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6);
+            p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7);
+
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3));
+            sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6));
+            sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7));
+
+        }
+
+        sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0);
+        sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1);
+        const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0);
+        acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#elif defined __riscv_v_intrinsic
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+
+        const int8_t * restrict scale = x[i].scales;
+
+        size_t vl;
+
+        vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
+
+        int sum_t = 0;
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            vl = 32;
+
+            // load qh
+            vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl);
+
+            // load Q6
+            vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl);
+            vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl);
+
+            vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl);
+            vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl);
+            vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl);
+            vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl);
+
+            vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl);
+            vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl);
+            vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl);
+            vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl);
+
+            vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl);
+            vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl);
+            vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl);
+            vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl);
+
+            vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl);
+            vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl);
+            vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl);
+            vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl);
+
+            // load Q8 and take product
+            vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl);
+            vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl);
+            vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl);
+            vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl);
+
+            vl = 16;
+
+            vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl);
+            vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl);
+            vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl);
+            vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl);
+            vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl);
+            vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl);
+            vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl);
+            vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl);
+
+            vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl);
+            vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl);
+            vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl);
+            vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl);
+
+            sum_t += __riscv_vmv_x_s_i32m1_i32(isum3);
+
+            q6 += 64;   qh += 32;   q8 += 128;   is=8;
+
+        }
+
+        sumf += d * sum_t;
+
+    }
+
+    *s = sumf;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v2 = vec_splats((unsigned char)0x2);
+    const vector unsigned char v3 = vec_splats((unsigned char)0x3);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+    const vector unsigned char v6 = vec_splats((unsigned char)0x6);
+    const vector signed char off = vec_splats((signed char)0x20);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+        vector signed int vsumi4 = v0;
+        vector signed int vsumi5 = v0;
+        vector signed int vsumi6 = v0;
+        vector signed int vsumi7 = v0;
+
+        const uint8_t * restrict q6 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict qs = x[i].scales;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+            __builtin_prefetch(q6, 0, 0);
+            __builtin_prefetch(qh, 0, 0);
+            __builtin_prefetch(q8, 0, 0);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q6);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q6);
+            vector signed char qxs2 = (vector signed char)vec_xl(32, q6);
+            vector signed char qxs3 = (vector signed char)vec_xl(48, q6);
+            q6 += 64;
+
+            vector signed char qxs00 = vec_and(qxs0, lowMask);
+            vector signed char qxs01 = vec_sr(qxs0, v4);
+            vector signed char qxs10 = vec_and(qxs1, lowMask);
+            vector signed char qxs11 = vec_sr(qxs1, v4);
+            vector signed char qxs20 = vec_and(qxs2, lowMask);
+            vector signed char qxs21 = vec_sr(qxs2, v4);
+            vector signed char qxs30 = vec_and(qxs3, lowMask);
+            vector signed char qxs31 = vec_sr(qxs3, v4);
+
+            vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh);
+            vector signed char qxhs1 = (vector signed char)vec_xl(16, qh);
+            qh += 32;
+
+            vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4);
+            vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4);
+            vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4);
+            vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4);
+            vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4);
+            vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4);
+            vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4);
+            vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4);
+
+            vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off);
+            vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off);
+            vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off);
+            vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off);
+            vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off);
+            vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off);
+            vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off);
+            vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off);
+
+            vector signed char q8y00 = vec_xl(  0, q8);
+            vector signed char q8y10 = vec_xl( 16, q8);
+            vector signed char q8y20 = vec_xl( 32, q8);
+            vector signed char q8y30 = vec_xl( 48, q8);
+            vector signed char q8y01 = vec_xl( 64, q8);
+            vector signed char q8y11 = vec_xl( 80, q8);
+            vector signed char q8y21 = vec_xl( 96, q8);
+            vector signed char q8y31 = vec_xl(112, q8);
+            q8 += 128;
+
+            vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00));
+            vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10));
+            vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20));
+            vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30));
+            vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01));
+            vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11));
+            vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21));
+            vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31));
+
+            vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8));
+            qs += 8;
+
+            vector signed short vs0 = vec_splat(vscales, 0);
+            vector signed short vs1 = vec_splat(vscales, 1);
+            vector signed short vs2 = vec_splat(vscales, 2);
+            vector signed short vs3 = vec_splat(vscales, 3);
+            vector signed short vs4 = vec_splat(vscales, 4);
+            vector signed short vs5 = vec_splat(vscales, 5);
+            vector signed short vs6 = vec_splat(vscales, 6);
+            vector signed short vs7 = vec_splat(vscales, 7);
+
+            vsumi0 = vec_msum(qv00, vs0, vsumi0);
+            vsumi1 = vec_msum(qv01, vs4, vsumi1);
+            vsumi2 = vec_msum(qv10, vs1, vsumi2);
+            vsumi3 = vec_msum(qv11, vs5, vsumi3);
+            vsumi4 = vec_msum(qv20, vs2, vsumi4);
+            vsumi5 = vec_msum(qv21, vs6, vsumi5);
+            vsumi6 = vec_msum(qv30, vs3, vsumi6);
+            vsumi7 = vec_msum(qv31, vs7, vsumi7);
+        }
+
+        vsumi0 = vec_add(vsumi0, vsumi4);
+        vsumi1 = vec_add(vsumi1, vsumi5);
+        vsumi2 = vec_add(vsumi2, vsumi6);
+        vsumi3 = vec_add(vsumi3, vsumi7);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined __loongarch_asx
+
+    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
+    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
+    const __m256i m32s = __lasx_xvreplgr2vr_b(32);
+
+    __m256 acc = (__m256)__lasx_xvldi(0);
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+
+        __m256i sumi = __lasx_xvldi(0);
+
+        int is = 0;
+
+        for (int j = 0; j < QK_K/128; ++j) {
+
+            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
+            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
+            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
+            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
+            is += 4;
+
+            const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+            const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
+            const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
+
+            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
+            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
+            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
+            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+
+            const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
+            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
+            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
+            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
+
+            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
+            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
+            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
+            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
+
+            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
+            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
+            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
+            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+
+            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
+            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
+            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
+            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
+        }
+
+        acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
+    }
+
+    *s = hsum_float_8(acc);
+
+#else
+
+    int8_t  aux8[QK_K];
+    int16_t aux16[8];
+    float   sums [8];
+    int32_t aux32[8];
+    memset(sums, 0, 8*sizeof(float));
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const  int8_t * restrict q8 = y[i].qs;
+        memset(aux32, 0, 8*sizeof(int32_t));
+        int8_t * restrict a = aux8;
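+        // expand the packed 6-bit values into aux8, combining low/high nibbles of ql
+        // with two bits from qh and removing the +32 bias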
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a  += 128;
+            q4 += 64;
+            qh += 32;
+        }
+        a = aux8;
+        int is = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            int scale = x[i].scales[is++];
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
+            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
+            q8 += 8; a += 8;
+        }
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
+    }
+    for (int l = 0; l < 8; ++l) sumf += sums[l];
+    *s = sumf;
+#endif
+}
+
+#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx)
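+// 128 sign patterns of 8 values each (+1/-1), indexed by the 7-bit sign fields of
+// iq2_xxs/iq2_xs blocks; the 8th sign is chosen so each pattern has an even number
+// of negative signs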
+static const int8_t keven_signs_q2xs[1024] = {
+     1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1,  1,
+     1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1,  1,  1, -1, -1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1, -1,
+     1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1, -1,
+     1,  1, -1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1,  1,
+     1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1, -1,
+     1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1,  1,
+     1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1,  1,
+     1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1,  1,  1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1, -1,
+     1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1,  1, -1,  1,  1, -1, -1,  1,  1,  1, -1,  1, -1,
+     1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1,  1,
+     1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1,  1,
+     1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1, -1,
+     1,  1,  1,  1, -1, -1,  1,  1, -1,  1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1,  1,
+     1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1, -1,
+     1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1, -1,
+     1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1,  1,
+     1,  1,  1,  1,  1,  1, -1, -1, -1,  1,  1,  1,  1,  1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1, -1,  1,  1,  1,  1, -1, -1,
+     1,  1, -1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1, -1,  1, -1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1,  1,
+     1,  1,  1, -1,  1,  1, -1,  1, -1,  1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1,  1,
+     1,  1, -1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1, -1,
+     1,  1,  1,  1, -1,  1, -1,  1, -1,  1,  1,  1, -1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1,  1,
+     1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1,  1,  1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1,
+     1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1, -1,  1,  1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1, -1,
+     1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1,
+     1,  1,  1,  1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1, -1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1,  1,
+     1,  1, -1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1,  1, -1, -1,  1, -1, -1, -1,  1,  1, -1, -1, -1,
+     1,  1,  1, -1,  1, -1, -1, -1, -1,  1,  1, -1,  1, -1, -1,  1,  1, -1,  1, -1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
+     1,  1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1,
+     1,  1,  1,  1, -1, -1, -1, -1, -1,  1,  1,  1, -1, -1, -1,  1,  1, -1,  1,  1, -1, -1, -1,  1, -1, -1,  1,  1, -1, -1, -1, -1,
+     1,  1, -1,  1, -1, -1, -1,  1, -1,  1, -1,  1, -1, -1, -1, -1,  1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1,  1,
+     1,  1,  1, -1, -1, -1, -1,  1, -1,  1,  1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1,  1,
+     1,  1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1,  1,  1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1,
+};
+#endif
+
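+// dot product of one row of IQ2_XXS blocks with one row of Q8_K blocks (n must be a multiple of QK_K)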
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
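+            // each pair of 32-value sub-blocks occupies 16 bytes: aux32[0]/aux32[2] hold four grid
+            // indices each, aux32[1]/aux32[3] hold 4x7 sign bits plus a 4-bit scale in the top nibble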
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1])));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3])));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9])));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11])));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >>  7) & 127))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.25f * sumf;
+
+#elif defined(__AVX2__)
+
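+    // same scheme as the NEON path; the signs are applied to the q8 values with _mm256_sign_epi8
+    // so that _mm256_maddubs_epi16 (which treats its first operand as unsigned) can be applied
+    // directly to the non-negative grid bytes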
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]);
+            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+    const vector int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            uint32_t aux32[4];
+            const uint8_t * aux8 = (const uint8_t *)aux32;
+
+            memcpy(aux32, q2, 4*sizeof(uint32_t));
+            q2 += 8;
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])};
+
+            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127))};
+            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))};
+            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127))};
+            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))};
+
+            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            const uint16_t ls0 = aux32[1] >> 28;
+            const uint16_t ls1 = aux32[3] >> 28;
+
+            vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[4];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
+
+            const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
+            const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
+            const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
+                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[1] >> 28;
+            const uint16_t ls2 = aux32[3] >> 28;
+            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
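+    // reference implementation: every 32-value sub-block is stored as two uint32 words, the first
+    // holding four grid indices and the second holding 4x7 sign bits plus a 4-bit scale on top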
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, q2, 2*sizeof(uint32_t));
+            q2 += 4;
+            const uint32_t ls = 2*(aux32[1] >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.125f * sumf;
+#endif
+}
+
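+// dot product of one row of IQ2_XS blocks with one row of Q8_K blocks: each uint16 in x[i].qs
+// packs a 9-bit grid index and 7 sign bits, with 4-bit sub-block scales kept in x[i].scales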
+void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_xs * restrict x = vx;
+    const block_q8_K   * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
+
+    int32x4x4_t scales32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+        const uint8x8_t scales8 = vld1_u8(x[i].scales);
+        const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
+        const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
+        uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+        scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
+        const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
+        const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
+        scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
+        scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
+        scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
+        scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
+        int32x4_t sumi = vdupq_n_s32(0);
+        for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
+            q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
+            q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
+            q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511))));
+            q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9))));
+            q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]);
+            q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]);
+            q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]);
+            q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]);
+            const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
+            const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
+            const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
+            const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
+            const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
+            sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
+            q2 += 8;
+        }
+        sumf += d*vaddvq_s32(sumi);
+    }
+    *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
+    const __m256i mone = _mm256_set1_epi8(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes);
+    const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1);
+    const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2);
+
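+    // the implicit 8th sign bit of each group is the parity of the 7 explicit bits; k_bit_helper,
+    // looked up via pshufb with the sign bits XOR-folded so their parity lands in the low nibble,
+    // reconstructs that bit in position 7 of full_sign_bits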
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper);
+    const __m256i m511 = _mm256_set1_epi16(511);
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky: the packed grid indices are read back from the vector register through a
+    // scalar pointer, but this gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = _mm_set1_epi64x(aux64);
+        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2);  q2 += 16;
+            aux_gindex = _mm256_and_si256(q2_data, m511);
+
+            const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9);
+            const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits);
+
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits);
+            const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1);
+            const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h);
+
+            __m256i signs;
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone));
+
+            signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2);
+            signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone));
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const __m256i dot3  = _mm256_maddubs_epi16(q2_3, q8s_3);
+            const __m256i dot4  = _mm256_maddubs_epi16(q2_4, q8s_4);
+
+            const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)));
+            const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)));
+
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2));
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4));
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    const __m128i mone = _mm_set1_epi8(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes);
+    const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1);
+    const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1);
+    const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1);
+    const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2);
+    const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1);
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper);
+    const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1);
+    const __m128i m511 = _mm_set1_epi16(511);
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky: the packed grid indices are read back from the vector register through a
+    // scalar pointer, but this gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = _mm_set1_epi64x(aux64);
+        stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4));
+        const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1);
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2);
+            const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1);  q2 += 16;
+            aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511));
+
+            const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9);
+            const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9);
+            const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13);
+            const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13);
+            const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0);
+            const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1);
+
+            const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0);
+            const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1);
+            const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0);
+            const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1);
+
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]);
+            const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]);
+            const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]);
+            const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+            const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]);
+
+            // AVX2 full_signs_1 is full_sign_bits_0 here
+            // AVX2 full_signs_2 is full_sign_bits_1 here
+            __m128i signs_0, signs_1;
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone));
+
+            signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0);
+            signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1);
+            signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0);
+            signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1);
+            const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone));
+            const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone));
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const __m128i dot3_0  = _mm_maddubs_epi16(q2_3_0, q8s_3_0);
+            const __m128i dot3_1  = _mm_maddubs_epi16(q2_3_1, q8s_3_1);
+            const __m128i dot4_0  = _mm_maddubs_epi16(q2_4_0, q8s_4_0);
+            const __m128i dot4_1  = _mm_maddubs_epi16(q2_4_1, q8s_4_1);
+
+            __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0));
+            const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1));
+            const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2));
+            const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+            sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3));
+            const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp);
+            const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8));
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1));
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__loongarch_asx)
+
+    const __m256i mone = __lasx_xvreplgr2vr_b(1);
+    static const char block_sign_shuffle_mask_1[32] = {
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
+        0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06,
+    };
+    static const char block_sign_shuffle_mask_2[32] = {
+        0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a,
+        0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e,
+    };
+    static const uint8_t bit_selector_mask_bytes[32] = {
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0);
+    const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0);
+    const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0);
+
+    static const uint8_t k_bit_helper[32] = {
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+        0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00,
+    };
+    const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0);
+    const __m256i m511 = __lasx_xvreplgr2vr_h(511);
+    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+    const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+    uint64_t aux64;
+
+    // somewhat hacky: the packed grid indices are read back from the vector register through a
+    // scalar pointer, but this gives a significant boost in performance
+    __m256i aux_gindex;
+    const uint16_t * gindex = (const uint16_t *)&aux_gindex;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        __m128i stmp = __lsx_vreplgr2vr_d(aux64);
+        stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4));
+        const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1);
+
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) {
+
+            const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0);  q2 += 16;
+            aux_gindex = __lasx_xvand_v(q2_data, m511);
+
+            const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9);
+            const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13);
+            const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper);
+
+            const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting);
+            const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits);
+
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+
+            const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]],
+                                                   iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]);
+            const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]],
+                                                   iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]);
+            const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]],
+                                                   iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]);
+            const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]],
+                                                   iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]);
+
+            const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0);
+            const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1);
+            const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l);
+            const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h);
+
+            __m256i signs;
+            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1);
+
+            signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2);
+
+            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3);
+
+            signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2);
+            signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask);
+            const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4);
+
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const __m256i dot3  = lasx_maddubs_h(q2_3, q8s_3);
+            const __m256i dot4  = lasx_maddubs_h(q2_4, q8s_4);
+
+            const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0)));
+            const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1)));
+            const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2)));
+            const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3)));
+
+            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1));
+            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2));
+            sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3));
+            sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4));
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+#elif defined(__POWER9_VECTOR__)
+    const vector int v0 = vec_splats((int32_t)0);
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint16_t * restrict q2 = x[i].qs;
+        const uint8_t  * restrict sc = x[i].scales;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))};
+
+            vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))};
+            vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))};
+            vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))};
+            vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))};
+            q2 += 8;
+
+            vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0);
+            vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1);
+            vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2);
+            vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);
+            sc += 2;
+
+            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+            vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.125f * vec_extract(vsumf0, 0);
+#else
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint16_t * restrict q2 = x[i].qs;
+        const uint8_t  * restrict sc = x[i].scales;
+        const int8_t   * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1;
+            const uint16_t ls2 = 2*(sc[ib32] >>  4) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 2; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls1;
+            sumi = 0;
+            for (int l = 2; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
+                const uint8_t  signs = ksigns_iq2xs[q2[l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += sumi * ls2;
+            q2 += 4;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.125f * sumf;
+#endif
+}
+
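+// dot product of one row of IQ2_S blocks with one row of Q8_K blocks: 8-bit grid indices in the
+// first half of x[i].qs with two extra high bits per index in x[i].qh, one explicit sign bit per
+// value stored in the second half of x[i].qs, and 4-bit sub-block scales in x[i].scales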
+void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq2_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined(__ARM_NEON)
+
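+    // mask1 replicates each byte of the packed 32-bit sign word across 8 lanes; AND-ing with mask2
+    // and comparing for equality expands the packed sign bits into full 0xff/0x00 byte masks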
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
+    const uint8x16_t        mask2 = vld1q_u8(k_mask2);
+    const uint8x16_t m1 = vdupq_n_u8(1);
+    const int32x4_t vzero = vdupq_n_s32(0);
+
+    uint8x16x2_t vs;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300)))));
+            q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300)))));
+            q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300)))));
+            q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))),
+                                     vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300)))));
+            qs += 8;
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vceqq_u8(vs.val[0], mask2);
+            vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+            q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]);
+            q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]);
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vceqq_u8(vs.val[0], mask2);
+            vs.val[1] = vceqq_u8(vs.val[1], mask2);
+
+            signs += 4;
+
+            q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]);
+            q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]);
+
+            const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]);
+            const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]);
+            const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]);
+            const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]);
+
+            sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf));
+            sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >>  4));
+            sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf));
+            sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >>  4));
+        }
+        sumf += d*(sumi1 + sumi2);
+    }
+
+    *s = 0.125f * sumf;
+
+#elif defined(__AVX2__)
+
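+    // same sign-expansion trick as the NEON path, done with _mm256_shuffle_epi8/_mm256_cmpeq_epi8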
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+    uint64_t aux64;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
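+        // Expand the 16 packed 4-bit block scales into 16-bit lanes holding 2*s+1;
+        // get_scale_shuffle_k4 later selects the right pair per 32-value sub-block.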
+        memcpy(&aux64, x[i].scales, 8);
+        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+        const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            qs += 8;
+
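+            // Apply the signs: shuffle+AND+cmpeq expands the packed sign bits into
+            // 0xFF/0x00 byte masks s, then (q8 ^ s) - s negates exactly the q8
+            // bytes whose sign bit is set.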
+            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0)));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1)));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
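+    // AVX1 variant: same logic as the AVX2 path above, but the integer work is
+    // done on 128-bit halves since AVX1 lacks 256-bit integer instructions.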
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = _mm_set1_epi8(0xf);
+    const __m128i m1 = _mm_set1_epi8(1);
+
+    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+    uint64_t aux64;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        memcpy(&aux64, x[i].scales, 8);
+        const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1);
+        const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8);
+        const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8));
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                                  iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                                  iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]);
+            const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                                  iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                                  iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]);
+            qs += 8;
+
+            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            __m128i aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+            signs += 4;
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0)));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1)));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0)));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1)));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
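+    // POWER9 VSX variant: vec_perm/vec_cmpeq expand the packed sign bits and
+    // vec_msum accumulates the scaled products, mirroring the paths above.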
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+    const vector unsigned char mask1 = vec_xl(16, k_mask1);
+    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t *  restrict q2 = x[i].qs;
+        const uint8_t *  restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const uint8_t *  restrict sc = x[i].scales;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q2, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))};
+            q2 += 8;
+            qh += 2;
+
+            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+            vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+            signs += 4;
+
+            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+            vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0);
+            vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1);
+
+            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+            vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0);
+            vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1);
+            vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2);
+            vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            const uint16_t ls2 = (uint16_t)(sc[1] & 0xf);
+            const uint16_t ls3 = (uint16_t)(sc[1] >>  4);
+            sc += 2;
+
+            vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1));
+            vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1));
+            vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1));
+            vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1));
+
+            vsumi0 = vec_msum(qv0, vscales0, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales1, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales2, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales3, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.125f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
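+    // LoongArch ASX (256-bit LASX) variant mirroring the AVX2 path above.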
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i m4 = __lsx_vreplgr2vr_b(0xf);
+    const __m128i m1 = __lsx_vreplgr2vr_b(1);
+
+    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+    uint64_t aux64;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8);
+        const int8_t  * restrict q8 = y[i].qs;
+
+        __m128i tmp1;
+        memcpy(&aux64, x[i].scales, 8);
+        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0);
+        tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1);
+        const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1);
+        const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
+
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)],
+                                                   iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)],
+                                                   iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)],
+                                                   iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]);
+            const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)],
+                                                   iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)],
+                                                   iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)],
+                                                   iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]);
+            qs += 8;
+
+            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+            aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3
+
+            const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0)));
+            const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1)));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = 0.125f * hsum_float_8(accumf);
+
+#else
+
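+    // Scalar reference fallback: same arithmetic as the SIMD paths above.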
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const int8_t  * q8 = y[i].qs;
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+        const uint8_t * signs = qs + QK_K/8;
+
+        int bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf);
+            int ls2 = 1 + 2*(x[i].scales[ib32] >>  4);
+            int sumi1 = 0, sumi2 = 0;
+            for (int l = 0; l < 2; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            for (int l = 2; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            bsum += ls1 * sumi1 + ls2 * sumi2;
+            qs += 4;
+            signs += 4;
+        }
+
+        sumf += d * bsum;
+    }
+
+    *s = 0.125f * sumf;
+
+#endif
+
+}
+
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_xxs * restrict x = vx;
+    const block_q8_K    * restrict y = vy;
+
+    const int nb = n / QK_K;
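+    // IQ3_XXS layout per super-block: the first QK_K/4 bytes of qs are 8-bit
+    // indices into iq3xxs_grid (4 values per index); the remaining QK_K/8 bytes
+    // pack, per 32-value sub-block, four 7-bit sign-pattern indices and a 4-bit scale.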
+
+#if defined(__ARM_NEON)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t   * restrict q8 = y[i].qs;
+        float sumf1 = 0, sumf2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+            memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t);
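+            // aux32[0]/aux32[1]: bits 0..27 hold four 7-bit indices into the
+            // precomputed keven_signs_q2xs sign patterns, bits 28..31 the block scale.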
+            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]);
+            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]);
+            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]);
+            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]);
+            q3 += 16;
+            q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >>  7) & 127))));
+            q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127))));
+            q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >>  0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >>  7) & 127))));
+            q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127))));
+            q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0));
+            q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1));
+            q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2));
+            q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3));
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+            sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28));
+            sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28));
+        }
+        sumf += d*(sumf1 + sumf2);
+    }
+    *s = 0.5f * sumf;
+
+#elif defined(__AVX2__)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                  iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
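+            // keven_signs_q2xs entries are precomputed 8-byte +1/-1 patterns, so
+            // _mm256_sign_epi8 applies the signs as an elementwise multiply.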
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
+            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#elif defined(__AVX__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+            q3 += 8;
+            const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+            const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]);
+            const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]);
+            const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0);
+            const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1);
+            const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0);
+            const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1);
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4);
+        const int8_t  * restrict q8 = y[i].qs;
+
+#pragma GCC unroll 1
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]};
+            vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]};
+            vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]};
+            vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]};
+            q3 += 16;
+
+            vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >>  0) & 127]), (uint64_t)(signs64[(signs[0] >>  7) & 127])};
+            vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])};
+            vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >>  0) & 127]), (uint64_t)(signs64[(signs[1] >>  7) & 127])};
+            vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])};
+
+            vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0);
+            vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1);
+            vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2);
+            vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(signs[0] >> 28);
+            const uint16_t ls1 = (uint16_t)(signs[1] >> 28);
+            signs += 2;
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = 0.25f * vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+
+    uint32_t aux32[2];
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]],
+                                                iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]);
+            q3 += 8;
+            memcpy(aux32, gas, 8); gas += 8;
+
+            const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127],
+                                                   signs64[(aux32[0] >>  7) & 127], signs64[(aux32[0] >>  0) & 127]);
+            const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
+                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
+            const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1);
+            const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2);
+            const __m256i dot1  = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const uint16_t ls1 = aux32[0] >> 28;
+            const uint16_t ls2 = aux32[1] >> 28;
+
+            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = 0.25f * hsum_float_8(accumf);
+
+#else
+
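+    // Scalar reference fallback, mirroring the SIMD paths above.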
+    uint32_t aux32;
+
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict gas = x[i].qs + QK_K/4;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t);
+            const uint32_t ls = 2*(aux32 >> 28) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
+                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            q3 += 8;
+            bsum += sumi * ls;
+        }
+        sumf += d * bsum;
+    }
+    *s = 0.25f * sumf;
+#endif
+}
+
+void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq3_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
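+    // IQ3_S layout per super-block: qs holds the 8-bit low index bits, qh supplies
+    // one extra high bit per index into the 512-entry iq3s_grid, signs stores one
+    // sign bit per weight, and each scales byte carries two 4-bit sub-block scales.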
+
+#if defined(__ARM_NEON)
+
+    typedef union {
+        uint16x8_t vec_index;
+        uint16_t   index[8];
+    } vec_index_t;
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
+
+    const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1);
+    const uint8x16_t        mask2 = vld1q_u8(k_mask2);
+
+    const int16x8_t  hshift = vld1q_s16(k_shift);
+    const uint16x8_t m256   = vdupq_n_u16(256);
+    const uint8x16_t m1     = vdupq_n_u8(1);
+
+    uint8x16x2_t vs;
+    ggml_int8x16x4_t q3s;
+    ggml_int8x16x4_t q8b;
+    vec_index_t idx;
+
+    uint32_t scales32[2];
+    const uint8_t * scales8 = (const uint8_t *)scales32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t   * restrict q8 = y[i].qs;
+
+        memcpy(scales32, x[i].scales, 4);
+        scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101;
+        scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101;
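+        // scales8[0..3] now hold 2*s+1 for the low nibbles and scales8[4..7] for
+        // the high nibbles, one byte per pair of 32-value sub-blocks.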
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const uint8x16_t idx_l = vld1q_u8(qs); qs += 16;
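+            // Bit j of qh[ib32] supplies the 9th (high) bit of grid index j:
+            // hshift left-shifts the broadcast qh byte by 8-j and m256 keeps only bit 8.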
+            idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256));
+            const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
+                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
+            const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
+                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
+            idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256));
+            const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]],
+                                                        iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]);
+            const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]],
+                                                        iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]);
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
+            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
+
+            q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0));
+            q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1));
+
+            vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16)));
+            vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
+            vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2);
+            vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1);
+            vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1);
+
+            signs += 4;
+
+            q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2));
+            q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3));
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
+
+            sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0];
+            sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4];
+        }
+        sumf += d*(sumi1 + sumi2);
+    }
+    *s = sumf;
+
+#elif defined(__AVX2__)
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
+    const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
+
+    const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
+    const __m256i idx_mask  = _mm256_set1_epi32(256);
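+    // idx_shift/idx_mask extract one qh bit per 32-bit lane (via a variable left
+    // shift) and use it as the 9th bit of the corresponding iq3s_grid index.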
+
+    typedef union {
+        __m256i  vec[2];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16;
+            idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]);
+            idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]);
+            idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask);
+            idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask);
+            idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l)));
+            idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1)));
+
+            // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
+            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+            const __m256i q2_1 = _mm256_set_epi32(
+                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+            );
+            const __m256i q2_2 = _mm256_set_epi32(
+                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+            );
+
+            __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1);
+
+            aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2);
+            const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2);
+            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1  = _mm256_maddubs_epi16(q2_1, q8s_1);
+            const __m256i dot2  = _mm256_maddubs_epi16(q2_2, q8s_2);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            sumi1 = _mm256_add_epi32(sumi1, p1);
+            sumi2 = _mm256_add_epi32(sumi2, p2);
+        }
+
+        accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
+
+    }
+
+    *s = hsum_float_8(accumf);
+
+#elif defined(__AVX__)
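+    // AVX1 variant of the path above; the per-lane variable shift is emulated with
+    // _mm_mullo_epi32 (idx_mul_0/idx_mul_1) since AVX1 has no vpsllvd.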
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1);
+    const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1);
+    const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2);
+    const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1);
+
+    const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256);
+    const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16);
+    const __m128i idx_mask  = _mm_set1_epi32(256);
+
+    typedef union {
+        __m128i  vec[4];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs);
+            const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp);
+            const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16;
+            idx.vec[0] = _mm_set1_epi32(qh[ib32+0]);
+            idx.vec[1] = idx.vec[0];
+            idx.vec[2] = _mm_set1_epi32(qh[ib32+1]);
+            idx.vec[3] = idx.vec[2];
+
+            idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask);
+            idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask);
+            idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask);
+            idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask);
+
+            idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0));
+            idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8)));
+            idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1));
+            idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8)));
+
+            const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]);
+            const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]);
+            const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]);
+            const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]);
+
+            __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16));
+            __m128i aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0);
+            const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1);
+
+            aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16));
+            aux128_1 = aux128_0;
+            aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0);
+            aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1);
+            const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0);
+            const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1);
+            const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0);
+            const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1);
+
+            signs += 4;
+
+            const __m128i dot1_0  = _mm_maddubs_epi16(q2_1_0, q8s_1_0);
+            const __m128i dot1_1  = _mm_maddubs_epi16(q2_1_1, q8s_1_1);
+            const __m128i dot2_0  = _mm_maddubs_epi16(q2_2_0, q8s_2_0);
+            const __m128i dot2_1  = _mm_maddubs_epi16(q2_2_1, q8s_2_1);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1));
+            sumi1_0 = _mm_add_epi32(sumi1_0, p1_0);
+            sumi1_1 = _mm_add_epi32(sumi1_1, p1_1);
+            sumi2_0 = _mm_add_epi32(sumi2_0, p2_0);
+            sumi2_1 = _mm_add_epi32(sumi2_1, p2_1);
+        }
+
+        accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf);
+
+    }
+
+    *s = hsum_float_8(accumf);
+
+#elif defined(__POWER9_VECTOR__)
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,};
+
+    const vector int v0 = vec_splats((int32_t)0);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector unsigned char mask0 = vec_xl( 0, k_mask1);
+    const vector unsigned char mask1 = vec_xl(16, k_mask1);
+    const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        const uint8_t *  restrict q3 = x[i].qs;
+        const uint8_t *  restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)(x[i].signs);
+        const uint8_t *  restrict sc = x[i].scales;
+        const int8_t  *  restrict q8 = y[i].qs;
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q3, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)],
+                                             iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]};
+            vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)],
+                                             iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]};
+            vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)],
+                                             iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]};
+            vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)],
+                                             iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]};
+            q3 += 16;
+            qh += 2;
+
+            vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]);
+            vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]);
+            signs += 4;
+
+            vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0);
+            vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1);
+            vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0);
+            vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1);
+
+            vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2);
+            vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2);
+            vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2);
+            vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2);
+
+            vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0);
+            vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1);
+            vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2);
+            vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(sc[0] & 0xf);
+            const uint16_t ls1 = (uint16_t)(sc[0] >>  4);
+            sc ++;
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
+                                        0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03
+    };
+
+    static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+                                        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
+    };
+
+    const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0);
+    const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0);
+
+    __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8);
+    const __m256i idx_mask  = __lasx_xvreplgr2vr_w(256);
+
+    typedef union {
+        __m256i  vec[2];
+        uint32_t index[16];
+    } index_t;
+
+    index_t idx;
+
+    __m256 accumf = (__m256)__lasx_xvldi(0);
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint16_t * restrict signs = (const uint16_t *)x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16;
+            idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]);
+            idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]);
+            idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask);
+            idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask);
+            idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0)));
+            idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1)));
+
+            // At least on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange.
+            //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4);
+            //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4);
+            const __m256i q2_1 = lasx_set_w(
+                    iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]],
+                    iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]
+            );
+            const __m256i q2_2 = lasx_set_w(
+                    iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]],
+                    iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]]
+            );
+
+            __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1);
+
+            aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16));
+            aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2);
+            const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2);
+            const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2);
+
+            signs += 4;
+
+            const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1);
+            const __m256i dot2  = lasx_maddubs_h(q2_2, q8s_2);
+            const uint16_t ls1 = x[i].scales[ib32/2] & 0xf;
+            const uint16_t ls2 = x[i].scales[ib32/2] >>  4;
+            const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1));
+            const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1));
+            sumi1 = __lasx_xvadd_w(sumi1, p1);
+            sumi2 = __lasx_xvadd_w(sumi2, p2);
+        }
+
+        accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf);
+    }
+
+    *s = hsum_float_8(accumf);
+
+#else
+
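+    // Reference scalar implementation, used when none of the SIMD paths above apply.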
+    float sumf = 0.f;
+    for (int i = 0; i < nb; ++i) {
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const uint8_t * restrict qs = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const uint8_t * restrict signs = x[i].signs;
+        const int8_t  * restrict q8 = y[i].qs;
+        int32_t bsum = 0;
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1;
+            const uint32_t ls2 = 2*(x[i].scales[ib32/2] >>  4) + 1;
+            int32_t sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            qs += 8;
+            signs += 4;
+            bsum += sumi * ls1;
+            sumi = 0;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1);
+                    sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1);
+                }
+                q8 += 8;
+            }
+            qs += 8;
+            signs += 4;
+            bsum += sumi * ls2;
+        }
+        sumf += d * bsum;
+    }
+    *s = sumf;
+#endif
+}
+
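+// Multiplies two vectors of signed 8-bit integers element-wise and sums each
+// adjacent pair of products into a signed 16-bit lane.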
+#if defined(__AVX2__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+#elif defined(__loongarch_asx)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = __lasx_xvsigncov_b(x, x);
+    const __m256i sy = __lasx_xvsigncov_b(x, y);
+    __m256i tmp1, tmp2, tmp3;
+    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
+    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
+    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
+    return __lasx_xvsat_h(tmp3, 15);
+}
+#endif
+
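+// Dot product of an IQ1_S-quantized row (vx) with a Q8_K-quantized row (vy)
+// over n values; the scalar result is written to *s.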
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_s * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+
+    ggml_int8x16x4_t q1b;
+    ggml_int8x16x4_t q8b;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        int sumi1 = 0, sumi2 = 0, sumi3 = 0;
+
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
+            qs += 8;
+
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
+            const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
+
+            const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            sumi1 += vaddvq_s32(p1) * ls1;
+            sumi2 += vaddvq_s32(p2) * ls2;
+            sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);
+
+        }
+
+        sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    __m256 accum = _mm256_setzero_ps();
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m256i sumi = _mm256_setzero_si256();
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
+                                                    iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)],
+                                                    iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            qs += 8;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
+            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
+
+            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
+        accum1 += d * sumi1;
+
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#elif defined __AVX__
+    __m256 accum = _mm256_setzero_ps();
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
+            const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]);
+            const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]);
+            const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]);
+            qs += 8;
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1));
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1));
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2));
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2));
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum);
+        accum1 += d * sumi1;
+
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#elif defined(__POWER9_VECTOR__)
+    const vector unsigned char v0 = vec_splats((unsigned char)0x0);
+    const vector unsigned short vsign = vec_splats((unsigned short)0x8000);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    for (int i = 0; i < nb; ++i) {
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d));
+        vector float vyd = vec_splats(y[i].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = vec_splats((int32_t)0);
+        vector signed int vsumi1 = vec_splats((int32_t)0);
+        vector signed int vsumi2 = vec_splats((int32_t)0);
+        vector signed int vsumi3 = vec_splats((int32_t)0);
+        vector signed int vsumi8 = vec_splats((int32_t)0);
+
+        const uint8_t  * restrict q1 = x[i].qs;
+        const uint16_t * restrict qh = x[i].qh;
+        const int8_t   * restrict q8 = y[i].qs;
+        const int16_t  * restrict qs = y[i].bsums;
+
+        for (int j = 0; j < QK_K/32; j += 2) {
+            __builtin_prefetch(q1, 0, 1);
+            __builtin_prefetch(qh, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))};
+            vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))};
+            vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))};
+            vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))};
+            q1 += 8;
+
+            vector signed char q1x0 = (vector signed char)aux64x2_0;
+            vector signed char q1x1 = (vector signed char)aux64x2_1;
+            vector signed char q1x2 = (vector signed char)aux64x2_2;
+            vector signed char q1x3 = (vector signed char)aux64x2_3;
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3));
+
+            const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7);
+            const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7);
+
+            vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1));
+            vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1));
+            vector signed short vscales = vec_sld(vscales23, vscales01, 8);
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+
+            vector signed short q8ysums = vec_xl_len(qs, 8);
+            qs += 4;
+            q8ysums = vec_mergeh(q8ysums, (vector signed short)v0);
+
+            vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8);
+            qh += 2;
+            vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0);
+
+            vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel);
+
+            vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    __m256 accum = (__m256)__lasx_xvldi(0);
+    float accum1 = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        __m256i sumi = __lasx_xvldi(0);
+        int sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            // start from a zeroed vector so the lane inserts below never read an indeterminate value
+            __m256i q1b_1 = __lasx_xvldi(0);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2);
+            q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3);
+
+            __m256i q1b_2 = __lasx_xvldi(0);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2);
+            q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3);
+
+            qs += 8;
+            const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+            const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+            const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
+            const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
+
+            __m256i tmp1, tmp5, tmp6;
+            tmp1 = __lasx_xvreplgr2vr_h(ls1);
+            tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1);
+            const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6);
+
+            tmp1 = __lasx_xvreplgr2vr_h(ls2);
+            tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1);
+            const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6);
+
+            sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2));
+            sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
+                   + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
+        }
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum);
+        accum1 += d * sumi1;
+    }
+
+    *s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
+
+#else
+
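+    // Reference scalar implementation, used when none of the SIMD paths above apply.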
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        int sumi = 0, sumi1 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = 2*((qh[ib] >> 12) & 7) + 1;
+            const int delta = qh[ib] & 0x8000 ? -1 : 1;
+            int lsum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+                for (int j = 0; j < 8; ++j) {
+                    lsum += q8[j] * grid[j];
+                }
+                q8 += 8;
+            }
+            sumi  += ls * lsum;
+            sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
+            qs += 4;
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
+    }
+
+    *s = sumf;
+
+#endif
+}
+
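+// Dot product of an IQ1_M-quantized row (vx) with a Q8_K-quantized row (vy)
+// over n values; the scalar result is written to *s.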
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(n % QK_K == 0);
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    const block_iq1_m * restrict x = vx;
+    const block_q8_K  * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+    iq1m_scale_t scale;
+
+#if defined __ARM_NEON
+    const int32x4_t mask  = vdupq_n_s32(0x7);
+    const int32x4_t mone  = vdupq_n_s32(1);
+    const int32x4_t mzero = vdupq_n_s32(0);
+
+    ggml_int8x16x4_t deltas;
+    deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1));
+    deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1));
+    deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1));
+    deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1));
+
+    ggml_int8x16x4_t q1b;
+    ggml_int8x16x4_t q8b;
+
+    uint32_t aux32;
+    const uint8_t * aux8 = (const uint8_t *)&aux32;
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        int32x4_t sumi1 = mzero;
+        int32x4_t sumi2 = mzero;
+
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+
+            q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
+            q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
+            q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
+            q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
+                                     vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
+
+            q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
+            const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
+            const int32x4_t p12 = vpaddq_s32(p1, p2);
+
+            const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that
+            aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202);
+
+            const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1]));
+            const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3]));
+            const int32x4_t p34 = vpaddq_s32(p3, p4);
+
+            int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9);
+
+            scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
+
+            sumi1 = vmlaq_s32(sumi1, scales_4, p12);
+            sumi2 = vmlaq_s32(sumi2, scales_4, p34);
+
+            qs += 8; qh += 4;
+
+        }
+
+        sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m256i mask = _mm256_set1_epi16(0x7);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m256i q1b_1 = _mm256_set_epi64x(
+                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)],
+                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]
+            );
+            const __m256i q1b_2 = _mm256_set_epi64x(
+                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)],
+                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]
+            );
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
+
+            const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
+            const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
+
+            const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+            const __m256i dot3 = mul_add_epi8(delta1, q8b_1);
+            const __m256i dot4 = mul_add_epi8(delta2, q8b_2);
+
+            __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0));
+            __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6));
+
+            scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone);
+            scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone);
+            const __m256i p1 = _mm256_madd_epi16(dot1, scale1);
+            const __m256i p2 = _mm256_madd_epi16(dot2, scale2);
+            const __m256i p3 = _mm256_madd_epi16(dot3, scale1);
+            const __m256i p4 = _mm256_madd_epi16(dot4, scale2);
+
+            sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2));
+            sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4));
+
+            qs += 8; qh += 4;
+        }
+
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+        accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1);
+        accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2);
+    }
+
+    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#elif defined __AVX__
+    const __m128i mask = _mm_set1_epi16(0x7);
+    const __m128i mone = _mm_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (int i = 0; i < nb; ++i) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q1b_1_0 = _mm_set_epi64x(
+                    iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]);
+            const __m128i q1b_1_1 = _mm_set_epi64x(
+                    iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]);
+            const __m128i q1b_2_0 = _mm_set_epi64x(
+                    iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]);
+            const __m128i q1b_2_1 = _mm_set_epi64x(
+                    iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]);
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+
+            const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0);
+            const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1);
+            const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0);
+            const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1);
+
+            const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+            const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101,
+                                                     qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101);
+
+            const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0);
+            const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1);
+            const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0);
+            const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1);
+
+            __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0);
+            __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3);
+            __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6);
+            __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9);
+
+            scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone);
+            scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone);
+            scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone);
+            scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone);
+            const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0);
+            const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1);
+            const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0);
+            const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1);
+            const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0);
+            const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1);
+            const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0);
+            const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1);
+
+            sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0));
+            sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1));
+            sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0));
+            sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1));
+
+            qs += 8; qh += 4;
+        }
+
+        const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16));
+
+        accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1);
+        accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2);
+    }
+
+    *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2);
+
+#else
+
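+    // Reference scalar implementation, used when none of the SIMD paths above apply.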
+    int sum1[2], sum2[2], delta[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; i++) {
+
+        const int8_t   * q8 = y[i].qs;
+        const uint8_t  * qs = x[i].qs;
+        const uint8_t  * qh = x[i].qh;
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            delta[0] = qh[0] & 0x08 ? -1 : 1;
+            delta[1] = qh[0] & 0x80 ? -1 : 1;
+            delta[2] = qh[1] & 0x08 ? -1 : 1;
+            delta[3] = qh[1] & 0x80 ? -1 : 1;
+            sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0;
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700)));
+                int lsum1 = 0, lsum2 = 0;
+                for (int j = 0; j < 8; ++j) {
+                    lsum1 += q8[j] * grid[j];
+                    lsum2 += q8[j];
+                }
+                q8 += 8;
+                sum1[l/2] += lsum1;
+                sum2[l/2] += lsum2*delta[l];
+            }
+
+            const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1;
+            const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1;
+
+            sumi1 += sum1[0] * ls1 + sum1[1] * ls2;
+            sumi2 += sum2[0] * ls1 + sum2[1] * ls2;
+            qs += 4;
+            qh += 2;
+        }
+
+        sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2);
+    }
+
+    *s = sumf;
+
+#endif
+}
+
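+// Dot product of an IQ4_NL-quantized row (vx) with a Q8_0-quantized row (vy)
+// over n values; the scalar result is written to *s.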
+void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK4_NL == 0);
+    static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same");
+
+    const block_iq4_nl * restrict x = vx;
+    const block_q8_0   * restrict y = vy;
+
+    const int nb = n / QK4_NL;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+                _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+                _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector signed int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+
+    const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d));
+        vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d));
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs);
+        vector signed char q4x0 = vec_and(qxs, lowMask);
+        vector signed char q4x1 = vec_sr(qxs, v4);
+
+        q4x0 = vec_perm(values, values, (vector unsigned char)q4x0);
+        q4x1 = vec_perm(values, values, (vector unsigned char)q4x1);
+
+        vector signed char q8y0 = vec_xl( 0, y[ib].qs);
+        vector signed char q8y1 = vec_xl(16, y[ib].qs);
+
+        vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
+        vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+
+        vsumi0 = vec_sum4s(qv0, vsumi0);
+        vsumi1 = vec_sum4s(qv1, vsumi1);
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    sumf = vec_extract(vsumf0, 0);
+
+#elif defined (__loongarch_asx)
+
+    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
+    const __m256i mone = __lasx_xvreplgr2vr_h(1);
+
+    __m256 accum1 = (__m256)__lasx_xvldi(0);
+    __m256 accum2 = (__m256)__lasx_xvldi(0);
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0);
+        const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0);
+        const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0);
+        const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0);
+        const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)),
+                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b)));
+        const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)),
+                                              lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = lasx_madd_h(p16_1, mone);
+        const __m256i p_2 = lasx_madd_h(p16_2, mone);
+        accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
+                __lasx_xvffint_s_w(p_1), accum1);
+        accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
+                __lasx_xvffint_s_w(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
+
+#endif
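+    // Scalar tail: processes any blocks the SIMD loop above left over
+    // (or the full range when no SIMD path is compiled in).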
+    for (; ib < nb; ++ib) {
+        const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
+        int sumi1 = 0, sumi2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            sumi1 += y[ib].qs[j+       0] * kvalues_iq4nl[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
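+// Dot product of an IQ4_XS-quantized row (vx) with a Q8_K-quantized row (vy)
+// over n values; the scalar result is written to *s.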
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_K == 0);
+
+    const block_iq4_xs * restrict x = vx;
+    const block_q8_K   * restrict y = vy;
+
+    const int nb = n / QK_K;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_iq4nl);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    ggml_uint8x16x2_t q4bits;
+    ggml_int8x16x4_t q4b;
+    ggml_int8x16x4_t q8b;
+    int32x4_t prod_1, prod_2;
+
+    float sumf = 0;
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+
+        const int8_t  * q8 = y[ibl].qs;
+        const uint8_t * q4 = x[ibl].qs;
+        uint16_t h = x[ibl].scales_h;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/64; ++ib) {
+
+            q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
+            q8b    = ggml_vld1q_s8_x4(q8); q8 += 64;
+
+            q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+            q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+            q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+            q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+            prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+            prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+            int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
+            h >>= 4;
+            sumi1 += vaddvq_s32(prod_1) * ls1;
+            sumi2 += vaddvq_s32(prod_2) * ls2;
+
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
+#elif defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs);  qs += 16;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+            const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                                  _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
+            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
+            sumi1 = _mm256_add_epi32(p_1, sumi1);
+            sumi2 = _mm256_add_epi32(p_2, sumi2);
+        }
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m128i sumi1_0 = _mm_setzero_si128();
+        __m128i sumi1_1 = _mm_setzero_si128();
+        __m128i sumi2_0 = _mm_setzero_si128();
+        __m128i sumi2_1 = _mm_setzero_si128();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16;
+            const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16;
+            const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+            const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+            const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+            const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+            const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
+            const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
+            const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
+            const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1));
+            const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1));
+            const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2));
+            const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2));
+            sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0);
+            sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1);
+            sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0);
+            sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1);
+        }
+        __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0);
+        __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1);
+        accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#elif defined(__POWER9_VECTOR__)
+    const vector signed char lowMask = vec_splats((signed char)0xF);
+    const vector int v0 = vec_splats((int32_t)0);
+    const vector unsigned char v4 = vec_splats((unsigned char)0x4);
+
+    vector float vsumf0 = vec_splats(0.0f);
+    vector float vsumf1 = vec_splats(0.0f);
+    vector float vsumf2 = vec_splats(0.0f);
+    vector float vsumf3 = vec_splats(0.0f);
+
+    const vector signed char values = vec_xl( 0, kvalues_iq4nl);
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+
+        vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d));
+        vector float vyd = vec_splats(y[ibl].d);
+        vector float vd = vec_mul(vxd, vyd);
+
+        vector signed int vsumi0 = v0;
+        vector signed int vsumi1 = v0;
+        vector signed int vsumi2 = v0;
+        vector signed int vsumi3 = v0;
+
+        uint16_t h = x[ibl].scales_h;
+
+        const uint8_t * restrict q4 = x[ibl].qs;
+        const uint8_t * restrict sc = x[ibl].scales_l;
+        const int8_t  * restrict q8 = y[ibl].qs;
+
+        for (int ib = 0; ib < QK_K/64; ib ++ ) {
+            __builtin_prefetch(q4, 0, 1);
+            __builtin_prefetch(q8, 0, 1);
+
+            vector signed char qxs0 = (vector signed char)vec_xl( 0, q4);
+            vector signed char qxs1 = (vector signed char)vec_xl(16, q4);
+            q4 += 32;
+
+            vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask);
+            vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4);
+            vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask);
+            vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4);
+
+            q4x00 = vec_perm(values, values, (vector unsigned char)q4x00);
+            q4x01 = vec_perm(values, values, (vector unsigned char)q4x01);
+            q4x10 = vec_perm(values, values, (vector unsigned char)q4x10);
+            q4x11 = vec_perm(values, values, (vector unsigned char)q4x11);
+
+            vector signed char q8y0 = vec_xl( 0, q8);
+            vector signed char q8y1 = vec_xl(16, q8);
+            vector signed char q8y2 = vec_xl(32, q8);
+            vector signed char q8y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0));
+            vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1));
+            vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2));
+            vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3));
+
+            const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32);
+            const uint16_t ls1 = (uint16_t)(((sc[0] >>  4) | ((h << 2) & 0x30)) - 32);
+            h >>= 4;
+            sc ++;
+
+            vector signed short vscales01 = vec_splats((int16_t)ls0);
+            vector signed short vscales23 = vec_splats((int16_t)ls1);
+
+            vsumi0 = vec_msum(qv0, vscales01, vsumi0);
+            vsumi1 = vec_msum(qv1, vscales01, vsumi1);
+            vsumi2 = vec_msum(qv2, vscales23, vsumi2);
+            vsumi3 = vec_msum(qv3, vscales23, vsumi3);
+        }
+
+        vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0);
+        vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1);
+        vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2);
+        vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3);
+    }
+
+    vsumf0 = vec_add(vsumf0, vsumf2);
+    vsumf1 = vec_add(vsumf1, vsumf3);
+
+    vsumf0 = vec_add(vsumf0, vsumf1);
+
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
+    vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
+
+    *s = vec_extract(vsumf0, 0);
+
+#elif defined(__loongarch_asx)
+
+    const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
+    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
+
+    __m256 accum = (__m256)__lasx_xvldi(0);
+    __m256i tmp1;
+    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
+
+    mask_8f = __lsx_vreplgr2vr_b(0x8f);
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = __lasx_xvldi(0);
+        __m256i sumi2 = __lasx_xvldi(0);
+        __m128i zero = __lsx_vldi(0);
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
+            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp3 = __lsx_vand_v(tmp0, mask);
+            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
+
+            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp4 = __lsx_vand_v(tmp0, mask);
+            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
+
+            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
+
+            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp3 = __lsx_vand_v(tmp0, mask);
+            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
+
+            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
+            tmp0 = __lsx_vori_b(tmp2, 0x10);
+            mask = __lsx_vsle_b(zero, tmp2);
+            tmp4 = __lsx_vand_v(tmp0, mask);
+            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
+
+            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
+
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            __m256i tmp5, tmp6;
+            tmp1 = __lasx_xvreplgr2vr_h(ls1);
+            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
+            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
+            tmp1 = __lasx_xvreplgr2vr_h(ls2);
+            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
+            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
+            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            sumi1 = __lasx_xvadd_w(p_1, sumi1);
+            sumi2 = __lasx_xvadd_w(p_2, sumi2);
+        }
+        accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum);
+    }
+
+    *s = hsum_float_8(accum);
+
+#else
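+    // scalar reference path: each 32-element sub-block carries a 6-bit scale,
+    // assembled from 4 low bits in scales_l (two scales per byte) plus 2 high bits
+    // taken from scales_h (4 bits consumed per pair of sub-blocks via h >>= 4),
+    // then offset by -32 before it multiplies the int8 dot products.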
+    float sumf = 0;
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
+        uint16_t h = x[ibl].scales_h;
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t  * q8 = y[ibl].qs;
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
+            const uint8_t ls2 = (x[ibl].scales_l[ib/2] >>  4) | ((h << 2) & 0x30);
+            h >>= 4;
+            const float d1 = d4d8*(ls1 - 32);
+            const float d2 = d4d8*(ls2 - 32);
+            int sumi1 = 0, sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
+            }
+            sumf += d1 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+            sumi1 = sumi2 = 0;
+            for (int j = 0; j < 16; ++j) {
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >>  4];
+            }
+            sumf += d2 * (sumi1 + sumi2);
+            qs += 16;
+            q8 += 32;
+        }
+    }
+    *s = sumf;
+#endif
+}
+
+// ============================ 4-bit non-linear quants
+
+void quantize_row_iq4_nl(const float * restrict x, void * restrict y, int64_t k) {
+    assert(k % QK4_NL == 0);
+    quantize_row_iq4_nl_ref(x, y, k);
+}
+
+void quantize_row_iq4_xs(const float * restrict x, void * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    quantize_iq4_xs(x, y, 1, k, NULL);
+}
diff --git a/llama/ggml-cpu-quants.h b/llama/ggml-cpu-quants.h
new file mode 100644
index 000000000..e2cdf03ed
--- /dev/null
+++ b/llama/ggml-cpu-quants.h
@@ -0,0 +1,89 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML CPU internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml-cpu-traits.cpp b/llama/ggml-cpu-traits.cpp
new file mode 100644
index 000000000..6d7ca0246
--- /dev/null
+++ b/llama/ggml-cpu-traits.cpp
@@ -0,0 +1,62 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-cpu-traits.h"
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+
+namespace ggml::cpu {
+tensor_traits::~tensor_traits() {}
+
+extra_buffer_type::~extra_buffer_type() {}
+}  // namespace ggml::cpu
+
+bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra && extra->context) {
+            auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
+            auto tensor_traits = buf_extra->get_tensor_traits(op);
+            if (tensor_traits && tensor_traits->compute_forward(params, op)) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra && extra->context) {
+            auto buf_extra     = (ggml::cpu::extra_buffer_type *) extra->context;
+            auto tensor_traits = buf_extra->get_tensor_traits(op);
+            if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) {
+                return true;
+            }
+        }
+    }
+    return false;
+}
diff --git a/llama/ggml-cpu-traits.h b/llama/ggml-cpu-traits.h
new file mode 100644
index 000000000..dcd7855fd
--- /dev/null
+++ b/llama/ggml-cpu-traits.h
@@ -0,0 +1,64 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+#include "ggml-backend-impl.h"
+#include "ggml-cpu-impl.h"
+#include "ggml.h"
+
+#ifdef __cplusplus
+#    include <vector>
+extern "C" {
+#endif
+
+// return true if op is part of an extra "accelerator"
+bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op);
+bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size);
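+// (the CPU backend calls these before its generic compute path, so an extra
+//  buffer type can claim and execute an op, or report its work-buffer needs)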
+
+#ifdef __cplusplus
+}
+
+namespace ggml::cpu {
+// register in tensor->extra
+class tensor_traits {
+  public:
+    virtual ~tensor_traits();
+    virtual bool work_size(int n_threads, const struct ggml_tensor * op, size_t & size)        = 0;
+    virtual bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) = 0;
+};
+
+class extra_buffer_type {
+  public:
+    virtual ~extra_buffer_type();
+    virtual bool            supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) = 0;
+    virtual tensor_traits * get_tensor_traits(const struct ggml_tensor * op)                   = 0;
+};
+}  // namespace ggml::cpu
+
+// implemented in ggml-cpu.cpp.
+std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type();
+
+#endif
diff --git a/llama/ggml-cpu.c b/llama/ggml-cpu.c
new file mode 100644
index 000000000..272f03e31
--- /dev/null
+++ b/llama/ggml-cpu.c
@@ -0,0 +1,14207 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
+
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-cpu-traits.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include "ggml-cpu-quants.h"
+#include "ggml-threading.h"
+#include "amx.h"
+#include "ggml.h"
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <errno.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <float.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <signal.h>
+#if defined(__gnu_linux__)
+#include <syscall.h>
+#endif
+
+#ifdef GGML_USE_OPENMP
+#include <omp.h>
+#endif
+
+#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
+#undef GGML_USE_LLAMAFILE
+#endif
+
+#ifdef GGML_USE_LLAMAFILE
+#include "llamafile/sgemm.h"
+#endif
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnings
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
+
+// unreachable code because of multiple instances of code after GGML_ABORT
+#pragma warning(disable: 4702)
+#endif
+
+// Note: once we move threading into a separate C++ file
+// will use std::hardware_destructive_interference_size instead of hardcoding it here
+// and we'll use C++ attribute syntax.
+#define GGML_CACHE_LINE  64
+
+#if defined(__clang__) || defined(__GNUC__)
+#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE)))
+#endif
+
+#if defined(__has_feature)
+#if __has_feature(thread_sanitizer)
+#define GGML_TSAN_ENABLED 1
+#endif
+#else  // __has_feature
+#if defined(__SANITIZE_THREAD__)
+#define GGML_TSAN_ENABLED 1
+#endif
+#endif // __has_feature
+
+#define UNUSED GGML_UNUSED
+#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0)
+
+#if defined(GGML_USE_ACCELERATE)
+#include <Accelerate/Accelerate.h>
+#endif
+
+// floating point type used to accumulate sums
+typedef double ggml_float;
+
+#define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
+
+#define GGML_SOFT_MAX_UNROLL 4
+#define GGML_VEC_DOT_UNROLL  2
+#define GGML_VEC_MAD_UNROLL  32
+
+//
+// global data
+//
+
+// precomputed gelu table for f16 (128 KB)
+static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
+
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
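+// (both tables are indexed directly by the 16-bit pattern of the fp16 input:
+//  (1 << 16) entries * 2 bytes each = 128 KB; they are filled once at init time)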
+
+#if defined(__ARM_ARCH)
+struct ggml_arm_arch_features_type {
+    int has_neon;
+    int has_dotprod;
+    int has_i8mm;
+    int has_sve;
+    int sve_cnt;
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
+#endif
+
+
+#if defined(_WIN32)
+
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+
+#if defined(_MSC_VER) && !defined(__clang__)
+#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE))
+
+typedef volatile LONG atomic_int;
+typedef atomic_int atomic_bool;
+typedef atomic_int atomic_flag;
+
+#define ATOMIC_FLAG_INIT 0
+
+typedef enum {
+    memory_order_relaxed,
+    memory_order_consume,
+    memory_order_acquire,
+    memory_order_release,
+    memory_order_acq_rel,
+    memory_order_seq_cst
+} memory_order;
+
+static void atomic_store(atomic_int * ptr, LONG val) {
+    InterlockedExchange(ptr, val);
+}
+static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) {
+    // TODO: add support for explicit memory order
+    InterlockedExchange(ptr, val);
+}
+static LONG atomic_load(atomic_int * ptr) {
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) {
+    // TODO: add support for explicit memory order
+    return InterlockedCompareExchange(ptr, 0, 0);
+}
+static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) {
+    // TODO: add support for explicit memory order
+    return InterlockedExchangeAdd(ptr, inc);
+}
+static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
+    return InterlockedExchange(ptr, 1);
+}
+static void atomic_flag_clear(atomic_flag * ptr) {
+    InterlockedExchange(ptr, 0);
+}
+static void atomic_thread_fence(memory_order mo) {
+    MemoryBarrier();
+}
+#else // clang
+#include <stdatomic.h>
+#endif
+
+typedef HANDLE pthread_t;
+
+typedef DWORD thread_ret_t;
+static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
+    (void) unused;
+    HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
+    if (handle == NULL)
+    {
+        return EAGAIN;
+    }
+
+    *out = handle;
+    return 0;
+}
+
+static int pthread_join(pthread_t thread, void * unused) {
+    (void) unused;
+    int ret = (int) WaitForSingleObject(thread, INFINITE);
+    CloseHandle(thread);
+    return ret;
+}
+
+static int sched_yield (void) {
+    Sleep (0);
+    return 0;
+}
+#else
+
+#include <pthread.h>
+#include <stdatomic.h>
+#include <sched.h>
+#if defined(__FreeBSD__)
+#include <pthread_np.h>
+#endif
+
+typedef void * thread_ret_t;
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#endif
+
+typedef pthread_t ggml_thread_t;
+
+#if defined(__APPLE__)
+#include <unistd.h>
+#include <mach/mach.h>
+#include <TargetConditionals.h>
+#endif
+
+//
+// cache line
+//
+
+#if defined(__cpp_lib_hardware_interference_size)
+#define CACHE_LINE_SIZE hardware_destructive_interference_size
+#else
+#if defined(__POWER9_VECTOR__)
+#define CACHE_LINE_SIZE 128
+#else
+#define CACHE_LINE_SIZE 64
+#endif
+#endif
+
+static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
+
+
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+
+static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_F32] = {
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
+        .vec_dot_type             = GGML_TYPE_F32,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_F16] = {
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f16,
+        .vec_dot_type             = GGML_TYPE_F16,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q4_0] = {
+        .from_float               = quantize_row_q4_0,
+        .vec_dot                  = ggml_vec_dot_q4_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
+    },
+    [GGML_TYPE_Q4_1] = {
+        .from_float               = quantize_row_q4_1,
+        .vec_dot                  = ggml_vec_dot_q4_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
+    },
+    [GGML_TYPE_Q5_0] = {
+        .from_float               = quantize_row_q5_0,
+        .vec_dot                  = ggml_vec_dot_q5_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .from_float               = quantize_row_q5_1,
+        .vec_dot                  = ggml_vec_dot_q5_1_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .from_float               = quantize_row_q8_0,
+        .vec_dot                  = ggml_vec_dot_q8_0_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+#if defined (__ARM_FEATURE_MATMUL_INT8)
+        .nrows                    = 2,
+#else
+        .nrows                    = 1,
+#endif
+    },
+    [GGML_TYPE_Q8_1] = {
+        .from_float               = quantize_row_q8_1,
+        .vec_dot_type             = GGML_TYPE_Q8_1,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q2_K] = {
+        .from_float               = quantize_row_q2_K,
+        .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .from_float               = quantize_row_q3_K,
+        .vec_dot                  = ggml_vec_dot_q3_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .from_float               = quantize_row_q4_K,
+        .vec_dot                  = ggml_vec_dot_q4_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .from_float               = quantize_row_q5_K,
+        .vec_dot                  = ggml_vec_dot_q5_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .from_float               = quantize_row_q6_K,
+        .vec_dot                  = ggml_vec_dot_q6_K_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ2_XXS] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq2_xxs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ2_XS] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ3_XXS] = {
+        // NOTE: from_float for iq3 and iq2_s was removed because these quants require initialization in ggml_quantize_init
+        //.from_float               = quantize_row_iq3_xxs,
+        .vec_dot                  = ggml_vec_dot_iq3_xxs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ3_S] = {
+        //.from_float               = quantize_row_iq3_s,
+        .vec_dot                  = ggml_vec_dot_iq3_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ2_S] = {
+        //.from_float               = quantize_row_iq2_s,
+        .vec_dot                  = ggml_vec_dot_iq2_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ1_S] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq1_s_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ1_M] = {
+        .from_float               = NULL,
+        .vec_dot                  = ggml_vec_dot_iq1_m_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ4_NL] = {
+        .from_float               = quantize_row_iq4_nl,
+        .vec_dot                  = ggml_vec_dot_iq4_nl_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_IQ4_XS] = {
+        .from_float               = quantize_row_iq4_xs,
+        .vec_dot                  = ggml_vec_dot_iq4_xs_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_Q8_K] = {
+        .from_float               = quantize_row_q8_K,
+    },
+    [GGML_TYPE_BF16] = {
+        .from_float               = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+        .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_bf16,
+        .vec_dot_type             = GGML_TYPE_BF16,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_TQ1_0] = {
+        .from_float               = quantize_row_tq1_0,
+        .vec_dot                  = ggml_vec_dot_tq1_0_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+    [GGML_TYPE_TQ2_0] = {
+        .from_float               = quantize_row_tq2_0,
+        .vec_dot                  = ggml_vec_dot_tq2_0_q8_K,
+        .vec_dot_type             = GGML_TYPE_Q8_K,
+        .nrows                    = 1,
+    },
+};
+
+const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
+    return &type_traits_cpu[type];
+}
+
+//
+// simd mappings
+//
+
+// we define a common set of C macros which map to specific intrinsics based on the current architecture
+// we then implement the fundamental computation operations below using only these macros
+// adding support for new architectures requires to define the corresponding SIMD macros
+//
+// GGML_F32_STEP / GGML_F16_STEP
+//   number of elements to process in a single step
+//
+// GGML_F32_EPR / GGML_F16_EPR
+//   number of elements to fit in a single register
+//
+
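+// For example, a dot product kernel composes these macros roughly as follows
+// (GGML_F32_ARR = GGML_F32_STEP/GGML_F32_EPR accumulators are kept in flight,
+//  then collapsed with a single reduce):
+//
+//   GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+//   for (int i = 0; i < n - n % GGML_F32_STEP; i += GGML_F32_STEP) {
+//       for (int j = 0; j < GGML_F32_ARR; j++) {
+//           GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+//           GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+//           sum[j]          = GGML_F32_VEC_FMA(sum[j], ax, ay);  // sum += ax*ay
+//       }
+//   }
+//   GGML_F32_VEC_REDUCE(sumf, sum);  // horizontal add across lanes and registers
+//
+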
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA)
+
+#define GGML_SIMD
+
+// F32 NEON
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              float32x4_t
+#define GGML_F32x4_ZERO         vdupq_n_f32(0.0f)
+#define GGML_F32x4_SET1(x)      vdupq_n_f32(x)
+#define GGML_F32x4_LOAD         vld1q_f32
+#define GGML_F32x4_STORE        vst1q_f32
+#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
+#define GGML_F32x4_ADD          vaddq_f32
+#define GGML_F32x4_MUL          vmulq_f32
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
+#define GGML_F32x4_REDUCE(res, x)                       \
+{                                                       \
+    int offset = GGML_F32_ARR >> 1;                     \
+    for (int i = 0; i < offset; ++i) {                  \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
+    }                                                   \
+    offset >>= 1;                                       \
+    for (int i = 0; i < offset; ++i) {                  \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
+    }                                                   \
+    offset >>= 1;                                       \
+    for (int i = 0; i < offset; ++i) {                  \
+        (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \
+    }                                                   \
+    (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 NEON
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    #define GGML_F16_STEP 32
+    #define GGML_F16_EPR  8
+
+    #define GGML_F16x8              float16x8_t
+    #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f)
+    #define GGML_F16x8_SET1(x)      vdupq_n_f16(x)
+    #define GGML_F16x8_LOAD(x)      vld1q_f16((const ggml_fp16_internal_t *)(x))
+    #define GGML_F16x8_STORE        vst1q_f16
+    #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c)
+    #define GGML_F16x8_ADD          vaddq_f16
+    #define GGML_F16x8_MUL          vmulq_f16
+    #define GGML_F16x8_REDUCE(res, x)                               \
+    do {                                                            \
+        int offset = GGML_F16_ARR >> 1;                             \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        offset >>= 1;                                               \
+        for (int i = 0; i < offset; ++i) {                          \
+            (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \
+        }                                                           \
+        const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \
+        const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \
+        (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
+    } while (0)
+
+    #define GGML_F16_VEC                GGML_F16x8
+    #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F16x8_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i])
+    #define GGML_F16_VEC_FMA            GGML_F16x8_FMA
+    #define GGML_F16_VEC_ADD            GGML_F16x8_ADD
+    #define GGML_F16_VEC_MUL            GGML_F16x8_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE
+#else
+    // if FP16 vector arithmetic is not supported, we use FP32 instead
+    // and take advantage of the vcvt_ functions to convert to/from FP16
+
+    #define GGML_F16_STEP 16
+    #define GGML_F16_EPR  4
+
+    #define GGML_F32Cx4              float32x4_t
+    #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f)
+    #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x)
+    #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x)))
+    #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y))
+    #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c)
+    #define GGML_F32Cx4_ADD          vaddq_f32
+    #define GGML_F32Cx4_MUL          vmulq_f32
+    #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE
+
+    #define GGML_F16_VEC                GGML_F32Cx4
+    #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO
+    #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1
+    #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p)
+    #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i])
+    #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA
+    #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD
+    #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL
+    #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE
+#endif
+
+#elif defined(__AVX512F__)
+
+#define GGML_SIMD
+
+// F32 AVX512
+
+#define GGML_F32_STEP 64
+#define GGML_F32_EPR  16
+
+#define GGML_F32x16         __m512
+#define GGML_F32x16_ZERO    _mm512_setzero_ps()
+#define GGML_F32x16_SET1(x) _mm512_set1_ps(x)
+#define GGML_F32x16_LOAD    _mm512_loadu_ps
+#define GGML_F32x16_STORE   _mm512_storeu_ps
+// _mm512_fmadd_ps is defined in AVX512F so no guard is required
+#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32x16_ADD     _mm512_add_ps
+#define GGML_F32x16_MUL     _mm512_mul_ps
+#define GGML_F32x16_REDUCE(res, x)                                    \
+do {                                                                  \
+    int offset = GGML_F32_ARR >> 1;                                   \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    offset >>= 1;                                                     \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    offset >>= 1;                                                     \
+    for (int i = 0; i < offset; ++i) {                                \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \
+    }                                                                 \
+    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                    \
+} while (0)
+
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x16
+#define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x16_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x16_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x16_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x16_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x16_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE
+
+// F16 AVX512
+
+// F16 AVX
+
+#define GGML_F16_STEP 64
+#define GGML_F16_EPR  16
+
+// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead
+
+#define GGML_F32Cx16             __m512
+#define GGML_F32Cx16_ZERO        _mm512_setzero_ps()
+#define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x)
+
+// unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F
+// so F16C guard isn't required
+#define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
+#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0))
+
+#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a)
+#define GGML_F32Cx16_ADD         _mm512_add_ps
+#define GGML_F32Cx16_MUL         _mm512_mul_ps
+#define GGML_F32Cx16_REDUCE(res, x)                               \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    res = (ggml_float) _mm512_reduce_add_ps(x[0]);                \
+} while (0)
+
+#define GGML_F16_VEC                GGML_F32Cx16
+#define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL
+
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE
+#elif defined(__AVX__)
+
+#define GGML_SIMD
+
+// F32 AVX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    _mm256_setzero_ps()
+#define GGML_F32x8_SET1(x) _mm256_set1_ps(x)
+#define GGML_F32x8_LOAD    _mm256_loadu_ps
+#define GGML_F32x8_STORE   _mm256_storeu_ps
+#if defined(__FMA__)
+    #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a)
+#endif
+#define GGML_F32x8_ADD     _mm256_add_ps
+#define GGML_F32x8_MUL     _mm256_mul_ps
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \
+    }                                                             \
+    const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \
+                                 _mm256_extractf128_ps(x[0], 1)); \
+    const __m128 t1 = _mm_hadd_ps(t0, t0);                        \
+    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 AVX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not supported by AVX, so we use F32 instead
+
+#define GGML_F32Cx8             __m256
+#define GGML_F32Cx8_ZERO        _mm256_setzero_ps()
+#define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x)
+
+#if defined(__F16C__)
+// the  _mm256_cvt intrinsics require F16C
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
+#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
+#else
+static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return _mm256_loadu_ps(tmp);
+}
+static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
+    float arr[8];
+
+    _mm256_storeu_ps(arr, y);
+
+    for (int i = 0; i < 8; i++)
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+}
+#define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y)
+#endif
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         _mm256_add_ps
+#define GGML_F32Cx8_MUL         _mm256_mul_ps
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_SIMD
+
+// F32 POWER9
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              vector float
+#define GGML_F32x4_ZERO         0.0f
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)              \
+{                                              \
+    int offset = GGML_F32_ARR >> 1;            \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    offset >>= 1;                              \
+    for (int i = 0; i < offset; ++i) {         \
+        x[i] = vec_add(x[i], x[offset+i]);     \
+    }                                          \
+    res = vec_extract(x[0], 0) +               \
+          vec_extract(x[0], 1) +               \
+          vec_extract(x[0], 2) +               \
+          vec_extract(x[0], 3);                \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 POWER9
+#define GGML_F16_STEP       GGML_F32_STEP
+#define GGML_F16_EPR        GGML_F32_EPR
+#define GGML_F16_VEC        GGML_F32x4
+#define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F16_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
+// Use vec_xl, not vec_ld, in case the load address is not aligned.
+#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \
+  vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \
+  vec_extract_fp32_from_shortl(vec_xl(0, p))
+#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i]
+#define GGML_F16_VEC_STORE(p, r, i)                             \
+  if (i & 0x1)                                                  \
+    vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \
+                                   r[i - GGML_ENDIAN_BYTE(0)]), \
+            0, p - GGML_F16_EPR)
+
+#elif defined(__wasm_simd128__)
+
+#define GGML_SIMD
+
+// F32 WASM
+
+#define GGML_F32_STEP 16
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              v128_t
+#define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f)
+#define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x)
+#define GGML_F32x4_LOAD         wasm_v128_load
+#define GGML_F32x4_STORE        wasm_v128_store
+#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a)
+#define GGML_F32x4_ADD          wasm_f32x4_add
+#define GGML_F32x4_MUL          wasm_f32x4_mul
+#define GGML_F32x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F32_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 WASM
+
+#define GGML_F16_STEP 16
+#define GGML_F16_EPR  4
+
+inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(p[0]);
+    tmp[1] = GGML_FP16_TO_FP32(p[1]);
+    tmp[2] = GGML_FP16_TO_FP32(p[2]);
+    tmp[3] = GGML_FP16_TO_FP32(p[3]);
+
+    return wasm_v128_load(tmp);
+}
+
+inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) {
+    float tmp[4];
+
+    wasm_v128_store(tmp, x);
+
+    p[0] = GGML_FP32_TO_FP16(tmp[0]);
+    p[1] = GGML_FP32_TO_FP16(tmp[1]);
+    p[2] = GGML_FP32_TO_FP16(tmp[2]);
+    p[3] = GGML_FP32_TO_FP16(tmp[3]);
+}
+
+#define GGML_F16x4             v128_t
+#define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f)
+#define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x)
+#define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x)
+#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y)
+#define GGML_F16x4_FMA         GGML_F32x4_FMA
+#define GGML_F16x4_ADD         wasm_f32x4_add
+#define GGML_F16x4_MUL         wasm_f32x4_mul
+#define GGML_F16x4_REDUCE(res, x)                  \
+{                                                  \
+    int offset = GGML_F16_ARR >> 1;                \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    offset >>= 1;                                  \
+    for (int i = 0; i < offset; ++i) {             \
+        x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \
+    }                                              \
+    res = wasm_f32x4_extract_lane(x[0], 0) +       \
+          wasm_f32x4_extract_lane(x[0], 1) +       \
+          wasm_f32x4_extract_lane(x[0], 2) +       \
+          wasm_f32x4_extract_lane(x[0], 3);        \
+}
+
+#define GGML_F16_VEC                GGML_F16x4
+#define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F16x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F16x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F16x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F16x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE
+
+#elif defined(__SSE3__)
+
+#define GGML_SIMD
+
+// F32 SSE
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    _mm_setzero_ps()
+#define GGML_F32x4_SET1(x) _mm_set1_ps(x)
+#define GGML_F32x4_LOAD    _mm_loadu_ps
+#define GGML_F32x4_STORE   _mm_storeu_ps
+#if defined(__FMA__)
+    // TODO: Does this work?
+    #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a)
+#else
+    #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a)
+#endif
+#define GGML_F32x4_ADD     _mm_add_ps
+#define GGML_F32x4_MUL     _mm_mul_ps
+#define GGML_F32x4_REDUCE(res, x)                                 \
+{                                                                 \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = _mm_add_ps(x[i], x[offset+i]);                     \
+    }                                                             \
+    const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \
+    res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \
+}
+// TODO: is this optimal ?
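+// (the three halving passes above fold the GGML_F32_ARR = 8 partial-sum vectors into x[0],
+//  which is then summed horizontally by the two hadd instructions)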
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 SSE
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return _mm_loadu_ps(tmp);
+}
+
+static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    _mm_storeu_ps(arr, y);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        _mm_setzero_ps()
+#define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x)
+#define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         _mm_add_ps
+#define GGML_F32Cx4_MUL         _mm_mul_ps
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#elif defined(__loongarch_asx)
+
+#define GGML_SIMD
+
+// F32 LASX
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  8
+
+#define GGML_F32x8         __m256
+#define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x))
+#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0)
+#define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0)
+#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a)
+#define GGML_F32x8_ADD     __lasx_xvfadd_s
+#define GGML_F32x8_MUL     __lasx_xvfmul_s
+#define GGML_F32x8_REDUCE(res, x)                                 \
+do {                                                              \
+    int offset = GGML_F32_ARR >> 1;                               \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    offset >>= 1;                                                 \
+    for (int i = 0; i < offset; ++i) {                            \
+        x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \
+    }                                                             \
+    float *tmp_p = (float *)&x[0]; \
+    res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \
+} while (0)
+// TODO: is this optimal ?
+
+#define GGML_F32_VEC        GGML_F32x8
+#define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x8_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x8_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x8_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x8_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x8_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE
+
+// F16 LASX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  8
+
+// F16 arithmetic is not natively supported here, so we use F32 instead
+
+#define GGML_F32Cx8          __m256
+#define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
+#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
+
+static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
+    float tmp[8];
+
+    for (int i = 0; i < 8; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return (__m256)__lasx_xvld(tmp, 0);
+}
+static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
+    float arr[8];
+
+    __lasx_xvst(y, arr, 0);
+
+    for (int i = 0; i < 8; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+#define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
+#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
+
+#define GGML_F32Cx8_FMA         GGML_F32x8_FMA
+#define GGML_F32Cx8_ADD         __lasx_xvfadd_s
+#define GGML_F32Cx8_MUL         __lasx_xvfmul_s
+#define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE
+
+#define GGML_F16_VEC                GGML_F32Cx8
+#define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1
+#define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA
+#define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD
+#define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE
+
+#elif defined(__loongarch_sx)
+
+#define GGML_SIMD
+
+// F32 LSX
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4         __m128
+#define GGML_F32x4_ZERO    __lsx_vldi(0)
+#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0)
+#define GGML_F32x4_STORE(x, y)   __lsx_vst((y), (x), 0)
+#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a)
+#define GGML_F32x4_ADD     __lsx_vfadd_s
+#define GGML_F32x4_MUL     __lsx_vfmul_s
+#define GGML_F32x4_REDUCE(res, x)                                                     \
+{                                                                                     \
+    int offset = GGML_F32_ARR >> 1;                                                   \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    offset >>= 1;                                                                     \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    offset >>= 1;                                                                     \
+    for (int i = 0; i < offset; ++i) {                                                \
+        x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \
+    }                                                                                 \
+    __m128i tmp     = __lsx_vsrli_d((__m128i) x[0], 32);                              \
+    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]);                    \
+    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
+    const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);                                     \
+    tmp             = __lsx_vsrli_d((__m128i) t0, 32);                                \
+    tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, t0);                      \
+    tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \
+    res             = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 LSX
+
+#define GGML_F16_STEP 32
+#define GGML_F16_EPR  4
+
+static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    tmp[0] = GGML_FP16_TO_FP32(x[0]);
+    tmp[1] = GGML_FP16_TO_FP32(x[1]);
+    tmp[2] = GGML_FP16_TO_FP32(x[2]);
+    tmp[3] = GGML_FP16_TO_FP32(x[3]);
+
+    return __lsx_vld(tmp, 0);
+}
+
+static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
+    float arr[4];
+
+    __lsx_vst(y, arr, 0);
+
+    x[0] = GGML_FP32_TO_FP16(arr[0]);
+    x[1] = GGML_FP32_TO_FP16(arr[1]);
+    x[2] = GGML_FP32_TO_FP16(arr[2]);
+    x[3] = GGML_FP32_TO_FP16(arr[3]);
+}
+
+#define GGML_F32Cx4             __m128
+#define GGML_F32Cx4_ZERO        __lsx_vldi(0)
+#define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0)
+#define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x)
+#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y)
+#define GGML_F32Cx4_FMA         GGML_F32x4_FMA
+#define GGML_F32Cx4_ADD         __lsx_vfadd_s
+#define GGML_F32Cx4_MUL         __lsx_vfmul_s
+#define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE
+
+#define GGML_F16_VEC                 GGML_F32Cx4
+#define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO
+#define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1
+#define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p)
+#define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i])
+#define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA
+#define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD
+#define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
+#define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
+
+#endif
+
+// GGML_F32_ARR / GGML_F16_ARR
+//   number of registers to use per step
+#ifdef GGML_SIMD
+#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR)
+#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
+#endif
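+// e.g. with the SSE3 settings above (GGML_F32_STEP = 32, GGML_F32_EPR = 4) this gives
+// GGML_F32_ARR = 8: each unrolled step processes 32 floats spread across 8 vector accumulators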
+
+//
+// Threading defs
+//
+
+typedef pthread_t          ggml_thread_t;
+
+#if defined(_WIN32)
+
+typedef CONDITION_VARIABLE ggml_cond_t;
+typedef SRWLOCK            ggml_mutex_t;
+
+#define ggml_mutex_init(m)   InitializeSRWLock(m)
+#define ggml_mutex_destroy(m)
+#define ggml_mutex_lock(m)   AcquireSRWLockExclusive(m)
+#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m)
+#define ggml_mutex_lock_shared(m)   AcquireSRWLockShared(m)
+#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m)
+
+#define ggml_cond_init(c)    InitializeConditionVariable(c)
+#define ggml_cond_destroy(c)
+#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED)
+#define ggml_cond_broadcast(c) WakeAllConditionVariable(c)
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
+#else
+
+typedef pthread_cond_t     ggml_cond_t;
+typedef pthread_mutex_t    ggml_mutex_t;
+
+#define ggml_mutex_init(m)          pthread_mutex_init(m, NULL)
+#define ggml_mutex_destroy(m)       pthread_mutex_destroy(m)
+#define ggml_mutex_lock(m)          pthread_mutex_lock(m)
+#define ggml_mutex_unlock(m)        pthread_mutex_unlock(m)
+#define ggml_mutex_lock_shared(m)   pthread_mutex_lock(m)
+#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m)
+
+#define ggml_lock_init(x)    UNUSED(x)
+#define ggml_lock_destroy(x) UNUSED(x)
+#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
+#define ggml_lock_lock(x)    _mm_pause()
+#else
+#define ggml_lock_lock(x)    UNUSED(x)
+#endif
+#define ggml_lock_unlock(x)  UNUSED(x)
+
+#define GGML_LOCK_INITIALIZER 0
+#define ggml_cond_init(c)      pthread_cond_init(c, NULL)
+#define ggml_cond_destroy(c)   pthread_cond_destroy(c)
+#define ggml_cond_wait(c, m)   pthread_cond_wait(c, m)
+#define ggml_cond_broadcast(c) pthread_cond_broadcast(c)
+
+#define ggml_thread_create pthread_create
+#define ggml_thread_join   pthread_join
+
+#endif
+
+// Threadpool def
+struct ggml_threadpool {
+    ggml_mutex_t mutex;       // mutex for cond.var
+    ggml_cond_t  cond;        // cond.var for waiting for new work
+
+    struct ggml_cgraph * cgraph;
+    struct ggml_cplan  * cplan;
+
+    // synchronization primitives
+    atomic_int n_graph;       // incremented when there is work to be done (i.e each graph)
+    atomic_int GGML_CACHE_ALIGN n_barrier;
+    atomic_int GGML_CACHE_ALIGN n_barrier_passed;
+    atomic_int current_chunk; // index of the chunk currently being processed during mat-mul, shared between all threads
+
+    // these are atomic as an annotation for thread-sanitizer
+    atomic_bool stop;         // Used for stopping the threadpool altogether
+    atomic_bool pause;        // Used for pausing the threadpool or individual threads
+    atomic_bool abort;        // Used for aborting processing of a graph
+
+    struct ggml_compute_state * workers;   // per thread state
+    int          n_threads_max; // number of threads in the pool
+    atomic_int   n_threads_cur; // number of threads used in the current graph
+
+    int32_t      prio;        // Scheduling priority
+    uint32_t     poll;        // Polling level (0 - no polling)
+
+    enum ggml_status ec;
+};
+
+// Per-thread state
+struct ggml_compute_state {
+#ifndef GGML_USE_OPENMP
+    ggml_thread_t thrd;
+    bool cpumask[GGML_MAX_N_THREADS];
+    int  last_graph;
+    bool pending;
+#endif
+    struct ggml_threadpool * threadpool;
+    int ith;
+};
+
+//
+// fundamental operations
+//
+
+inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v;    }
+inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+
+inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    }
+inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        }
+inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           }
+inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; }
+inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           }
+inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        }
+inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       }
+inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   }
+inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
+
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+#if defined(GGML_SIMD)
+    float sumf = 0.0f;
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F32_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += x[i]*y[i];
+    }
+#else
+    // scalar
+    ggml_float sumf = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(x[i]*y[i]);
+    }
+#endif
+
+    *s = sumf;
+}
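+// note: the SIMD path above keeps GGML_F32_ARR independent accumulators to hide FMA latency,
+// reduces them once at the end with GGML_F32_VEC_REDUCE, and finishes the remaining
+// n % GGML_F32_STEP elements in plain scalar code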
+
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    int i = 0;
+    ggml_float sumf = 0;
+
+#if defined(__AVX512BF16__)
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 64 <= n; i += 64) {
+        c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+                             m512bh(_mm512_loadu_si512((y + i))));
+        c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+                             m512bh(_mm512_loadu_si512((y + i + 32))));
+    }
+    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#elif defined(__AVX512F__)
+#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16))
+    __m512 c1 = _mm512_setzero_ps();
+    __m512 c2 = _mm512_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2);
+    }
+    sumf += (ggml_float)_mm512_reduce_add_ps(c1);
+    sumf += (ggml_float)_mm512_reduce_add_ps(c2);
+
+#undef LOAD
+#elif defined(__AVX2__) || defined(__AVX__)
+#if defined(__AVX2__)
+#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16))
+#else
+#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1))
+#endif
+    __m256 c1 = _mm256_setzero_ps();
+    __m256 c2 = _mm256_setzero_ps();
+    __m256 c3 = _mm256_setzero_ps();
+    __m256 c4 = _mm256_setzero_ps();
+    for (; i + 32 <= n; i += 32) {
+        c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1);
+        c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2);
+        c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3);
+        c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4);
+    }
+    __m128 g;
+    c1 = _mm256_add_ps(_mm256_add_ps(c1, c3),
+                       _mm256_add_ps(c2, c4));
+    g = _mm_add_ps(_mm256_extractf128_ps(c1, 1),
+                   _mm256_castps256_ps128(c1));
+    g = _mm_add_ps(g, _mm_movehl_ps(g, g));
+    g = _mm_add_ss(g, _mm_movehdup_ps(g));
+    sumf += (ggml_float)_mm_cvtss_f32(g);
+
+#undef LOAD
+#endif
+
+    for (; i < n; ++i) {
+        sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
+                             GGML_BF16_TO_FP32(y[i]));
+    }
+    *s = sumf;
+}
+
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+
+    ggml_float sumf = 0.0;
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    GGML_F16_VEC_REDUCE(sumf, sum);
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
+    }
+#endif
+
+    *s = sumf;
+}
+
+// compute GGML_VEC_DOT_UNROLL dot products at once
+// xs - x row stride in bytes
+inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
+    ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 };
+
+    ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        x[i] = (ggml_fp16_t *) ((char *) xv + i*xs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+
+            for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+                ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
+
+                sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
+            }
+        }
+    }
+
+    // reduce sum0..sum3 to sum0
+    for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
+        GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#else
+    for (int i = 0; i < n; ++i) {
+        for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
+        }
+    }
+#endif
+
+    for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) {
+        s[i] = sumf[i];
+    }
+}
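+// note: in the unrolled variant above each y vector is loaded once per inner iteration and
+// reused against GGML_VEC_DOT_UNROLL different x rows, yielding one dot product per row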
+
+inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ax[GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] += x[i]*v;
+    }
+#endif
+}
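+// ggml_vec_mad_* implement the axpy pattern y[i] += x[i]*v, using the same STEP/EPR
+// unrolling as the dot products above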
+
+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
+// xs and vs are byte strides of x and v
+inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
+
+    const float * restrict x[GGML_VEC_MAD_UNROLL];
+    const float * restrict v[GGML_VEC_MAD_UNROLL];
+
+    for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) {
+        x[i] = (const float *) ((const char *) xv + i*xs);
+        v[i] = (const float *) ((const char *) vv + i*vs);
+    }
+
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL];
+
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        vx[k] = GGML_F32_VEC_SET1(v[k][0]);
+    }
+
+    GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR];
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+
+            for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+                ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]);
+            }
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = np; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#else
+    // scalar
+    for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) {
+        for (int i = 0; i < n; ++i) {
+            y[i] += x[k][i]*v[k][0];
+        }
+    }
+#endif
+}
+
+//inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
+inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmul(y, 1, &v, y, 1, n);
+#elif defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F32_STEP - 1));
+
+    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);
+
+    GGML_F32_VEC ay[GGML_F32_ARR];
+
+    for (int i = 0; i < np; i += GGML_F32_STEP) {
+        for (int j = 0; j < GGML_F32_ARR; j++) {
+            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
+            ay[j] = GGML_F32_VEC_MUL(ay[j], vx);
+
+            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] *= v;
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] *= v;
+    }
+#endif
+}
+
+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
+inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
+inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
+inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
+inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
+inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
+inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
+inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
+// TODO: optimize performance
+inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
+inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); }
+
+static const float GELU_COEF_A     = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
+static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+inline static float ggml_gelu_f32(float x) {
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_table_gelu_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]);
+        }
+    }
+}
+#else
+inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]);
+    }
+}
+#endif
+
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+//    const uint16_t * i16 = (const uint16_t *) x;
+//    for (int i = 0; i < n; ++i) {
+//        y[i] = ggml_table_gelu_quick_f16[i16[i]];
+//    }
+//}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
+// Sigmoid Linear Unit (SiLU) function
+inline static float ggml_silu_f32(float x) {
+    return x/(1.0f + expf(-x));
+}
+
+#if __FINITE_MATH_ONLY__
+#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
+#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+#endif
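+// (for example, ggml_v_expf below relies on overflow producing +INFINITY, which
+//  finite-math-only optimizations would silently break)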
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static float32x4_t ggml_v_expf(float32x4_t x) {
+    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+    const float32x4_t n = vsubq_f32(z, r);
+    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+                                    vdupq_n_f32(0x1.7f7d1cp-20f));
+    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+    const float32x4_t u = vmulq_f32(b, b);
+    const float32x4_t j = vfmaq_f32(
+        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+        return vfmaq_f32(k, j, k);
+    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
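+// (roughly: the routine writes x = n*ln2 + b, with n obtained via the 0x1.8p23 rounding trick,
+//  builds 2^n by adding n to the float exponent bits, approximates e^b with a small polynomial
+//  in b, and patches up the overflow/underflow lanes flagged by |n| > 126)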
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static float32x4_t ggml_v_silu(float32x4_t x) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t zero = vdupq_n_f32(0.0f);
+    const float32x4_t neg_x = vsubq_f32(zero, x);
+    const float32x4_t exp_neg_x = ggml_v_expf(neg_x);
+    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+    return vdivq_f32(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m512 ggml_v_expf(__m512 x) {
+  const __m512 r = _mm512_set1_ps(0x1.8p23f);
+  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+  const __m512 n = _mm512_sub_ps(z, r);
+  const __m512 b =
+      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+                       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+  const __mmask16 d =
+      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+  const __m512 u = _mm512_mul_ps(b, b);
+  const __m512 j = _mm512_fmadd_ps(
+      _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+                                      _mm512_set1_ps(0x1.573e2ep-5f)),
+                      u,
+                      _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+                                      _mm512_set1_ps(0x1.fffdb6p-2f))),
+      u,
+      _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+  const __m512 res = _mm512_scalef_ps(j, n);
+  if (_mm512_kortestz(d, d))
+    return res;
+  const __m512 zero = _mm512_setzero_ps();
+  const __m512 alt = _mm512_mask_blend_ps(
+      _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+  return _mm512_mask_blend_ps(d, res, alt);
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m512 ggml_v_silu(__m512 x) {
+    const __m512 one = _mm512_set1_ps(1);
+    const __m512 zero = _mm512_setzero_ps();
+    const __m512 neg_x = _mm512_sub_ps(zero, x);
+    const __m512 exp_neg_x = ggml_v_expf(neg_x);
+    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+    return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m256 ggml_v_expf(__m256 x) {
+  const __m256 r = _mm256_set1_ps(0x1.8p23f);
+  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+  const __m256 n = _mm256_sub_ps(z, r);
+  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+  const __m256 k = _mm256_castsi256_ps(
+      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+  const __m256i c = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(126), _CMP_GT_OQ));
+  const __m256 u = _mm256_mul_ps(b, b);
+  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
+                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
+                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+    return _mm256_fmadd_ps(j, k, k);
+  const __m256i g = _mm256_and_si256(
+      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+      _mm256_set1_epi32(0x82000000u));
+  const __m256 s1 =
+      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+  const __m256i d = _mm256_castps_si256(
+      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+                    _mm256_set1_ps(192), _CMP_GT_OQ));
+  return _mm256_or_ps(
+      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+      _mm256_andnot_ps(
+          _mm256_castsi256_ps(d),
+          _mm256_or_ps(
+              _mm256_and_ps(_mm256_castsi256_ps(c),
+                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m256 ggml_v_silu(__m256 x) {
+    const __m256 one = _mm256_set1_ps(1);
+    const __m256 zero = _mm256_setzero_ps();
+    const __m256 neg_x = _mm256_sub_ps(zero, x);
+    const __m256 exp_neg_x = ggml_v_expf(neg_x);
+    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+    return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
+
+#if defined(__FMA__)
+#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
+#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
+#else
+#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
+#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
+#endif
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+inline static __m128 ggml_v_expf(__m128 x) {
+    const __m128 r = _mm_set1_ps(0x1.8p23f);
+    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
+    const __m128 n = _mm_sub_ps(z, r);
+    const __m128 b =
+        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
+    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+    const __m128i c =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+    const __m128 u = _mm_mul_ps(b, b);
+    const __m128 j =
+        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
+                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
+                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm_movemask_epi8(c))
+        return MADD128(j, k, k);
+    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
+                                    _mm_set1_epi32(0x82000000u));
+    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+    const __m128i d =
+        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+    return _mm_or_ps(
+        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+        _mm_andnot_ps(_mm_castsi128_ps(d),
+                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
+                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
+}
+
+// computes silu x/(1+exp(-x)) in single precision vector
+inline static __m128 ggml_v_silu(__m128 x) {
+    const __m128 one = _mm_set1_ps(1);
+    const __m128 zero = _mm_setzero_ps();
+    const __m128 neg_x = _mm_sub_ps(zero, x);
+    const __m128 exp_neg_x = ggml_v_expf(neg_x);
+    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
+    return _mm_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __ARM_NEON / __AVX2__ / __SSE2__
+
+static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]);
+    }
+}
+
+static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
+    int i = 0;
+    ggml_float sum = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+                                               _mm512_set1_ps(max)));
+        _mm512_storeu_ps(y + i, val);
+        sum += (ggml_float)_mm512_reduce_add_ps(val);
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+                                               _mm256_set1_ps(max)));
+        _mm256_storeu_ps(y + i, val);
+        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+                                 _mm256_castps256_ps128(val));
+        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+        sum += (ggml_float)_mm_cvtss_f32(val2);
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+                                            _mm_set1_ps(max)));
+        _mm_storeu_ps(y + i, val);
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+#else
+        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+        val = _mm_add_ps(val, tmp);
+        tmp = _mm_movehl_ps(tmp, val);
+        val = _mm_add_ss(val, tmp);
+#endif
+        sum += (ggml_float)_mm_cvtss_f32(val);
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+                                                vdupq_n_f32(max)));
+        vst1q_f32(y + i, val);
+        sum += (ggml_float)vaddvq_f32(val);
+    }
+#endif
+    for (; i < n; ++i) {
+        float val = expf(x[i] - max);
+        sum += (ggml_float)val;
+        y[i] = val;
+    }
+    return sum;
+}
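+// note: ggml_vec_soft_max_f32 only writes the un-normalized values y[i] = exp(x[i] - max) and
+// returns their sum; the caller is expected to divide by that sum to complete the softmax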
+
+static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) {
+    // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum)
+
+    int i = 0;
+    ggml_float sum = 0;
+    for (; i < n; ++i) {
+        float val = x[i] - max;
+        y[i] = val;
+        sum += (ggml_float)expf(val);
+    }
+    return (ggml_float)logf(sum);
+}
+
+inline static float ggml_silu_backward_f32(float x, float dy) {
+    const float s = 1.0f/(1.0f + expf(-x));
+    return dy*s*(1.0f + x*(1.0f - s));
+}
+
+inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
+    for (int i = 0; i < n; ++i) {
+        dx[i] = ggml_silu_backward_f32(x[i], dy[i]);
+    }
+}
+
+inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+#else
+    vDSP_sve(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) {
+    ggml_float sum = 0.0;
+    for (int i = 0; i < n; ++i) {
+        sum += (ggml_float)x[i];
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_FP16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) {
+    float sum = 0.0f;
+    for (int i = 0; i < n; ++i) {
+        sum += GGML_BF16_TO_FP32(x[i]);
+    }
+    *s = sum;
+}
+
+inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
+#ifndef GGML_USE_ACCELERATE
+    float max = -INFINITY;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    *s = max;
+#else
+    vDSP_maxv(x, 1, s, n);
+#endif
+}
+
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
+
+inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) {
+    float max = -INFINITY;
+    int idx = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+        if (max == x[i]) { idx = i; }
+    }
+    *s = idx;
+}
+
+// Helpers for polling loops
+#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) )
+static inline void ggml_thread_cpu_relax(void) {
+    __asm__ volatile("yield" ::: "memory");
+}
+#elif defined(__x86_64__)
+static inline void ggml_thread_cpu_relax(void) {
+    _mm_pause();
+}
+#else
+static inline void ggml_thread_cpu_relax(void) {;}
+#endif
+
+//
+// NUMA support
+//
+
+#define GGML_NUMA_MAX_NODES 8
+#define GGML_NUMA_MAX_CPUS 512
+
+struct ggml_numa_node {
+    uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node
+    uint32_t n_cpus;
+};
+
+struct ggml_numa_nodes {
+    enum ggml_numa_strategy numa_strategy;
+    struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES];
+    uint32_t n_nodes;
+    uint32_t total_cpus; // hardware threads on system
+    uint32_t current_node; // node on which main process is executing
+#if defined(__gnu_linux__)
+    cpu_set_t cpuset; // cpuset from numactl
+#else
+    uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype
+#endif
+};
+
+//
+// ggml state
+//
+
+struct ggml_state {
+    struct ggml_numa_nodes numa;
+};
+
+static struct ggml_state g_state = {0};
+
+void ggml_barrier(struct ggml_threadpool * tp) {
+    int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed);
+    if (n_threads == 1) {
+        return;
+    }
+
+#ifdef GGML_USE_OPENMP
+    #pragma omp barrier
+#else
+    int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed);
+
+    // enter barrier (full seq-cst fence)
+    int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst);
+
+    if (n_barrier == (n_threads - 1)) {
+        // last thread
+        atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed);
+
+        // exit barrier (full seq-cst fence)
+        atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst);
+        return;
+    }
+
+    // wait for other threads
+    while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) {
+        ggml_thread_cpu_relax();
+    }
+
+    // exit barrier (full seq-cst fence)
+    // TSAN doesn't support standalone fences yet, so we use a dummy read-modify-write instead
+    #ifdef GGML_TSAN_ENABLED
+    atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst);
+    #else
+    atomic_thread_fence(memory_order_seq_cst);
+    #endif
+#endif
+}
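+// the barrier above is a simple counting barrier: each thread increments n_barrier, the last
+// arrival resets it and bumps n_barrier_passed, and the remaining threads spin (with
+// ggml_thread_cpu_relax) until they observe n_barrier_passed change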
+
+#if defined(__gnu_linux__)
+static cpu_set_t ggml_get_numa_affinity(void) {
+    cpu_set_t cpuset;
+    pthread_t thread;
+    thread = pthread_self();
+    CPU_ZERO(&cpuset);
+    pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset);
+    return cpuset;
+}
+#else
+static uint32_t ggml_get_numa_affinity(void) {
+    return 0; // no NUMA support
+}
+#endif
+
+void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
+    if (g_state.numa.n_nodes > 0) {
+        fprintf(stderr, "ggml_numa_init: NUMA already initialized\n");
+
+        return;
+    }
+
+#if defined(__gnu_linux__)
+    struct stat st;
+    char path[256];
+    int rv;
+
+    // set numa scheme
+    g_state.numa.numa_strategy = numa_flag;
+
+    GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy);
+
+    g_state.numa.cpuset = ggml_get_numa_affinity();
+
+    // enumerate nodes
+    while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.n_nodes;
+    }
+
+    // enumerate CPUs
+    while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) {
+        rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus);
+        GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+        if (stat(path, &st) != 0) { break; }
+        ++g_state.numa.total_cpus;
+    }
+
+    GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus);
+
+    // figure out which node we're on
+    uint current_cpu;
+    int getcpu_ret = 0;
+#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__)
+    getcpu_ret = getcpu(&current_cpu, &g_state.numa.current_node);
+#else
+    // old glibc doesn't have a wrapper for this call. Fall back on direct syscall
+#   if !defined(SYS_getcpu) && defined(SYS_get_cpu)
+#       define SYS_getcpu SYS_get_cpu // some older glibc versions use this name
+#   endif
+    getcpu_ret = syscall(SYS_getcpu, &current_cpu, &g_state.numa.current_node);
+#endif
+
+    if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) {
+        g_state.numa.n_nodes = 0;
+        return;
+    }
+
+    GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu);
+
+    for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) {
+        struct ggml_numa_node * node = &g_state.numa.nodes[n];
+        GGML_PRINT_DEBUG("CPUs on node %u:", n);
+        node->n_cpus = 0;
+        for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) {
+            rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c);
+            GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path));
+            if (stat(path, &st) == 0) {
+                node->cpus[node->n_cpus++] = c;
+                GGML_PRINT_DEBUG(" %u", c);
+            }
+        }
+        GGML_PRINT_DEBUG("\n");
+    }
+
+    if (ggml_is_numa()) {
+        FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r");
+        if (fptr != NULL) {
+            char buf[42];
+            if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
+                GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
+            }
+            fclose(fptr);
+        }
+    }
+#else
+    UNUSED(numa_flag);
+    // TODO
+#endif
+}
+
+bool ggml_is_numa(void) {
+    return g_state.numa.n_nodes > 1;
+}
+
+#if defined(__ARM_ARCH)
+
+#if defined(__linux__) && defined(__aarch64__)
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <sys/sysctl.h>
+#endif
+
+#if !defined(HWCAP2_I8MM)
+#define HWCAP2_I8MM (1 << 13)
+#endif
+
+static void ggml_init_arm_arch_features(void) {
+#if defined(__linux__) && defined(__aarch64__)
+    uint32_t hwcap = getauxval(AT_HWCAP);
+    uint32_t hwcap2 = getauxval(AT_HWCAP2);
+
+    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
+    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
+    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
+
+#if defined(__ARM_FEATURE_SVE)
+    ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#endif
+#elif defined(__APPLE__)
+    int oldp = 0;
+    size_t size = sizeof(oldp);
+    if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_neon = oldp;
+
+    if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_dotprod = oldp;
+
+    if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_i8mm = oldp;
+
+    ggml_arm_arch_features.has_sve = 0;
+    ggml_arm_arch_features.sve_cnt = 0;
+#else
+// Run-time CPU feature detection not implemented for this platform, fallback to compile time
+#if defined(__ARM_NEON)
+    ggml_arm_arch_features.has_neon = 1;
+#else
+    ggml_arm_arch_features.has_neon = 0;
+#endif
+
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_arm_arch_features.has_i8mm = 1;
+#else
+    ggml_arm_arch_features.has_i8mm = 0;
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+    ggml_arm_arch_features.has_sve = 1;
+    ggml_arm_arch_features.sve_cnt = 16;
+#else
+    ggml_arm_arch_features.has_sve = 0;
+    ggml_arm_arch_features.sve_cnt = 0;
+#endif
+#endif
+}
+#endif
+
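+// Convenience constructors for single-element scalar tensors. The context must
+// have allocation enabled (ggml_get_no_alloc(ctx) == false) because the value
+// is written into the tensor immediately.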
+struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
+    GGML_ASSERT(!ggml_get_no_alloc(ctx));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
+
+    ggml_set_i32(result, value);
+
+    return result;
+}
+
+struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
+    GGML_ASSERT(!ggml_get_no_alloc(ctx));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+
+    ggml_set_f32(result, value);
+
+    return result;
+}
+
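+// Fill every element of the tensor with the given value, converted to the
+// tensor's element type; rows are addressed through the nb[1] stride.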
+struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_bf16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    return tensor;
+}
+
+struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
+    const int n     = ggml_nrows(tensor);
+    const int nc    = tensor->ne[0];
+    const size_t n1 = tensor->nb[1];
+
+    char * const data = tensor->data;
+
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                assert(tensor->nb[0] == sizeof(int8_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I16:
+            {
+                assert(tensor->nb[0] == sizeof(int16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_I32:
+            {
+                assert(tensor->nb[0] == sizeof(int32_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value);
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_fp16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                assert(tensor->nb[0] == sizeof(ggml_bf16_t));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value));
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                assert(tensor->nb[0] == sizeof(float));
+                for (int i = 0; i < n; i++) {
+                    ggml_vec_set_f32(nc, (float *)(data + i*n1), value);
+                }
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    return tensor;
+}
+
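+// 1-d accessors treat the tensor as a flat array of elements; for non-contiguous
+// tensors the flat index is first unravelled into 4-d coordinates.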
+int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                return ((int8_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                return ((int16_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                return ((int32_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_BF16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                return ((float *)(tensor->data))[i];
+            }
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                GGML_ASSERT(tensor->nb[0] == sizeof(float));
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
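+// n-d accessors compute the element's byte offset directly from the
+// per-dimension strides nb[0..3].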
+int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_BF16:
+            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]);
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                return ((int8_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I16:
+            {
+                return ((int16_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_I32:
+            {
+                return ((int32_t *)(tensor->data))[i];
+            }
+        case GGML_TYPE_F16:
+            {
+                return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_BF16:
+            {
+                return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
+            }
+        case GGML_TYPE_F32:
+            {
+                return ((float *)(tensor->data))[i];
+            }
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
+    if (!ggml_is_contiguous(tensor)) {
+        int64_t id[4] = { 0, 0, 0, 0 };
+        ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]);
+        ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value);
+        return;
+    }
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(tensor->data))[i] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(tensor->data))[i] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            return ((int8_t *) data)[0];
+        case GGML_TYPE_I16:
+            return ((int16_t *) data)[0];
+        case GGML_TYPE_I32:
+            return ((int32_t *) data)[0];
+        case GGML_TYPE_F16:
+            return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]);
+        case GGML_TYPE_BF16:
+            return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]);
+        case GGML_TYPE_F32:
+            return ((float *) data)[0];
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) {
+    void * data   = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3];
+    switch (tensor->type) {
+        case GGML_TYPE_I8:
+            {
+                ((int8_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I16:
+            {
+                ((int16_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ((int32_t *)(data))[0] = value;
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ((float *)(data))[0] = value;
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// ggml_compute_forward_dup
+
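+// Fast path for dup when src and dst share the same type and both are
+// contiguous: a straight memcpy, parallelized by splitting the flat element
+// range evenly across threads.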
+static void ggml_compute_forward_dup_same_cont(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    const size_t nb0 = ggml_type_size(src0->type);
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by elements
+    const int ne = ggml_nelements(dst);
+    const int dr = (ne + nth - 1) / nth;
+    const int ie0 = dr * ith;
+    const int ie1 = MIN(ie0 + dr, ne);
+
+    if (ie0 < ie1) {
+        memcpy(
+            ((char *)  dst->data + ie0*nb0),
+            ((char *) src0->data + ie0*nb0),
+            (ie1 - ie0) * nb0);
+    }
+}
+
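+// Type-converting dup from an F16 source. Work is split across threads by src
+// rows; matching row layouts are copied with memcpy, otherwise elements are
+// converted to F16/F32 or quantized through the destination type's from_float
+// routine.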
+static void ggml_compute_forward_dup_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_fp16_t)) {
+            if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+        return;
+    }
+
+    // dst counters
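+    // i10..i13 track the current write position in dst and wrap like an
+    // odometer as elements are emitted, so src and dst may have different
+    // shapes as long as the total number of elements matches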
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_bf16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+
+    if (ggml_is_contiguous(dst)) {
+        if (nb00 == sizeof(ggml_bf16_t)) {
+            if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00]));
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+                float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]);
+                            }
+
+                            quantize_row_q(src0_f32, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr));
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+        return;
+    }
+
+    // dst counters
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_BF16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t));
+
+                        if (++i10 == ne00) {
+                            i10 = 0;
+                            if (++i11 == ne01) {
+                                i11 = 0;
+                                if (++i12 == ne02) {
+                                    i12 = 0;
+                                    if (++i13 == ne03) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
+        // copy by rows
+        const size_t rs = ne00*nb00;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        // TODO: simplify
+        if (nb00 == sizeof(float)) {
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                const size_t rs = ne00 * nb00;
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, rs);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+                size_t id = 0;
+                size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
+                char * dst_ptr = (char *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += rs * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            id += rs;
+                        }
+                        id += rs * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            if (dst->type == GGML_TYPE_F32) {
+                size_t id = 0;
+                float * dst_ptr = (float *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_F16) {
+                size_t id = 0;
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else if (dst->type == GGML_TYPE_BF16) {
+                size_t id = 0;
+                ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr);
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
+            } else {
+                GGML_ABORT("fatal error"); // TODO: implement
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        memcpy(dst_ptr, src0_ptr, sizeof(float));
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else if (dst->type == GGML_TYPE_BF16) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr);
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
+static void ggml_compute_forward_dup_bytes(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+    GGML_ASSERT(src0->type == dst->type);
+
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
+        ggml_compute_forward_dup_same_cont(params, dst);
+        return;
+    }
+
+    const size_t type_size = ggml_type_size(src0->type);
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (src0->type == dst->type &&
+        ne00 == ne0 &&
+        nb00 == type_size && nb0 == type_size) {
+        // copy by rows
+        const size_t rs = ne00 * type_size;
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    memcpy(
+                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+                        rs);
+                }
+            }
+        }
+        return;
+    }
+
+    if (ggml_is_contiguous(dst)) {
+        size_t id = 0;
+        char * dst_ptr = (char *) dst->data;
+        const size_t rs = ne00 * type_size;
+
+        if (nb00 == type_size) {
+            // src0 is contiguous on the first dimension, copy by rows
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+                        memcpy(dst_ptr + id, src0_ptr, rs);
+                        id += rs;
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        } else {
+            //printf("%s: this is not optimal - fix me\n", __func__);
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    id += rs * ir0;
+                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+                            memcpy(dst_ptr + id, src0_ptr, type_size);
+
+                            id += type_size;
+                        }
+                    }
+                    id += rs * (ne01 - ir1);
+                }
+            }
+        }
+
+        return;
+    }
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            i10 += ne00 * ir0;
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+            for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                    memcpy(dst_ptr, src0_ptr, type_size);
+
+                    if (++i10 == ne0) {
+                        i10 = 0;
+                        if (++i11 == ne1) {
+                            i11 = 0;
+                            if (++i12 == ne2) {
+                                i12 = 0;
+                                if (++i13 == ne3) {
+                                    i13 = 0;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
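+            // likewise skip past the elements of rows [ir1, ne01) handled by other threads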
+            i10 += ne00 * (ne01 - ir1);
+            while (i10 >= ne0) {
+                i10 -= ne0;
+                if (++i11 == ne1) {
+                    i11 = 0;
+                    if (++i12 == ne2) {
+                        i12 = 0;
+                        if (++i13 == ne3) {
+                            i13 = 0;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_dup(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (src0->type == dst->type) {
+        ggml_compute_forward_dup_bytes(params, dst);
+        return;
+    }
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_dup_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_dup_bf16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_dup_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add
+
+static void ggml_compute_forward_add_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
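+            // src1's row is broadcast nr0 times along dim 0 (ggml_can_repeat guarantees ne00 is a multiple of ne10)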
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    }
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        if (dst->type == GGML_TYPE_F16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_bf16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+        GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    }
+
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        if (dst->type == GGML_TYPE_BF16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float *       dst_ptr  = (float *)       ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+                ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float *       src1_ptr = (float *)       ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_fp16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+            for (int i = 0; i < ne0; i++) {
+                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i]));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_bf16_bf16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(ggml_bf16_t)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src0, src1 and dst are same shape => same indices
+            const int i3 = ir/(ne2*ne1);
+            const int i2 = (ir - i3*ne2*ne1)/ne1;
+            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+            ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1);
+            ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+            ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+            for (int i = 0; i < ne0; i++) {
+                dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i]));
+            }
+        }
+    }
+    else {
+        // src1 is not contiguous
+        GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    const enum ggml_type dtype = dst->type;
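+    // type-specific helpers: dequantize src0 rows to f32, add src1, then convert the result back to dst's type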
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
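+    // per-thread f32 scratch row: ne00 floats plus one cache line of padding to avoid false sharing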
+    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, wdata, src1_row);
+        // quantize row to dst
+        if (quantize_row_q != NULL) {
+            quantize_row_q(wdata, dst_row, ne00);
+        } else {
+            memcpy(dst_row, wdata, ne0*nb0);
+        }
+    }
+}
+
+static void ggml_compute_forward_add(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                if (src1->type == GGML_TYPE_BF16) {
+                    ggml_compute_forward_add_bf16_bf16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_bf16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_add_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add1
+
+static void ggml_compute_forward_add1_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+#ifdef GGML_USE_ACCELERATE
+        UNUSED(ggml_vec_add1_f32);
+
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
+                (float *) ((char *) src1->data), 0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
+                ne0);
+#else
+        ggml_vec_add1_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+               *(float *) src1->data);
+#endif
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_f16_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_q_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
+
+    // we don't support permuted src0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ggml_is_quantized(src0->type));
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
+        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb0 ));
+
+        assert(ne0 % 32 == 0);
+
+        // dequantize row from src0 to temp buffer
+        dequantize_row_q(src0_row, wdata, ne0);
+        // add src1
+        ggml_vec_acc1_f32(ne0, wdata, v);
+        // quantize row to dst
+        quantize_row_q(wdata, dst_row, ne0);
+    }
+}
+
+static void ggml_compute_forward_add1_bf16_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = *(float *) src1->data;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1_bf16_bf16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_scalar(src1));
+
+    // scalar to add
+    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
+    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
+
+    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are same shape => same indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
+        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+        for (int i = 0; i < ne0; i++) {
+            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
+        }
+    }
+}
+
+static void ggml_compute_forward_add1(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add1_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add1_f16_f16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_f16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                if (src1->type == GGML_TYPE_BF16) {
+                    ggml_compute_forward_add1_bf16_bf16(params, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add1_bf16_f32(params, dst);
+                }
+                else {
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_add1_q_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_acc
+
+static void ggml_compute_forward_acc_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset in bytes during acc
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during acc
+    const size_t nb0 = ggml_element_size(src0);
+
+    const size_t nb00 = nb0;
+    const size_t nb01 = nb1;
+    const size_t nb02 = nb2;
+    const size_t nb03 = nb3;
+
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
+    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+#ifdef GGML_USE_ACCELERATE
+        vDSP_vadd(
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
+#else
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+#endif
+    }
+}
+
+static void ggml_compute_forward_acc(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_acc_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sub
+
+static void ggml_compute_forward_sub_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    if (nb10 == sizeof(float)) {
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sub(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sub_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mul
+
+static void ggml_compute_forward_mul_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0 ; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                UNUSED(ggml_vec_mul_f32);
+
+                vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mul(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now");
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mul_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_div
+
+static void ggml_compute_forward_div_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    if (nb10 == sizeof(float)) {
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
+#ifdef GGML_USE_ACCELERATE
+                UNUSED(ggml_vec_div_f32);
+
+                vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
+#else
+                ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
+#endif
+            }
+        }
+    } else {
+        // src1 is not contiguous
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
+
+                dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_div(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_div_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sqr
+
+static void ggml_compute_forward_sqr_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n     = ggml_nrows(src0);
+    const int nc    = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqr_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sqr(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqr_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sqrt
+
+static void ggml_compute_forward_sqrt_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sqrt_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sqrt(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sqrt_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_log
+
+static void ggml_compute_forward_log_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_log_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_log(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_log_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sin
+
+static void ggml_compute_forward_sin_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sin_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sin(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sin_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cos
+
+static void ggml_compute_forward_cos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_cos_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_cos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sum
+
+static void ggml_compute_forward_sum_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+    assert(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    ggml_float sum     = 0;
+    ggml_float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32_ggf(ne00,
+                        &row_sum,
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((float *) dst->data)[0] = sum;
+}
+
+static void ggml_compute_forward_sum_f16(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+
+    assert(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f16_ggf(ne00,
+                    &row_sum,
+                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum);
+}
+
+static void ggml_compute_forward_sum_bf16(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_scalar(dst));
+
+    assert(src0->nb[0] == sizeof(ggml_bf16_t));
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
+
+    float sum = 0;
+    float row_sum = 0;
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_bf16_ggf(ne00,
+                    &row_sum,
+                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
+                sum += row_sum;
+            }
+        }
+    }
+    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
+}
+
+static void ggml_compute_forward_sum(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_sum_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_sum_bf16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sum_rows
+
+static void ggml_compute_forward_sum_rows_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == 1);
+    GGML_ASSERT(ne1 == ne01);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    for (int64_t i3 = 0; i3 < ne03; i3++) {
+        for (int64_t i2 = 0; i2 < ne02; i2++) {
+            for (int64_t i1 = 0; i1 < ne01; i1++) {
+                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
+                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
+                float row_sum = 0;
+                ggml_vec_sum_f32(ne00, &row_sum, src_row);
+                dst_row[0] = row_sum;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_sum_rows(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sum_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mean
+
+static void ggml_compute_forward_mean_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    assert(ne0 == 1);
+    assert(ne1 == ne01);
+    assert(ne2 == ne02);
+    assert(ne3 == ne03);
+
+    UNUSED(ne0);
+    UNUSED(ne1);
+    UNUSED(ne2);
+    UNUSED(ne3);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                ggml_vec_sum_f32(ne00,
+                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
+                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
+
+                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mean(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_mean_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_argmax
+
+static void ggml_compute_forward_argmax_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(src0->nb[0] == sizeof(float));
+    assert(dst->nb[0] == sizeof(float));
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    const size_t nb01 = src0->nb[1];
+    const size_t nb0 = dst->nb[0];
+
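+    // for each row, store the index of its maximum element as an int32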
+    for (int64_t i1 = 0; i1 < ne01; i1++) {
+        float * src = (float *) ((char *) src0->data + i1*nb01);
+        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
+        int v = 0;
+        ggml_vec_argmax_f32(ne00, &v, src);
+        dst_[0] = v;
+    }
+}
+
+static void ggml_compute_forward_argmax(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argmax_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_count_equal
+
+static void ggml_compute_forward_count_equal_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_I64);
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
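+    // each thread accumulates a partial count into the shared work buffer; thread 0 reduces the partials after the barrier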
+    int64_t * sums = (int64_t *) params->wdata;
+    int64_t sum_thread = 0;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 =  ir                        / (ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)       /       ne01;
+        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;
+
+        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
+        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
+
+        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
+            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
+
+            sum_thread += val0 == val1;
+        }
+    }
+    if (ith != 0) {
+        sums[ith] = sum_thread;
+    }
+    ggml_barrier(params->threadpool);
+
+    if (ith != 0) {
+        return;
+    }
+
+    for (int ith_other = 1; ith_other < nth; ++ith_other) {
+        sum_thread += sums[ith_other];
+    }
+    *((int64_t *) dst->data) = sum_thread;
+}
+
+static void ggml_compute_forward_count_equal(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_count_equal_i32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_repeat
+
+static void ggml_compute_forward_repeat_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // TODO: maybe this is not optimal?
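+    // the i* loops iterate over the repeat counts and the k* loops over the source extents,
+    // so each source row is copied into every repetition of dst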
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_vec_cpy_f32(ne00,
+                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
+                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne0/ne00);
+    const int nr1 = (int)(ne1/ne01);
+    const int nr2 = (int)(ne2/ne02);
+    const int nr3 = (int)(ne3/ne03);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3;  i3++) {
+        for                     (int k3 = 0; k3 < ne03; k3++) {
+            for                 (int i2 = 0; i2 < nr2;  i2++) {
+                for             (int k2 = 0; k2 < ne02; k2++) {
+                    for         (int i1 = 0; i1 < nr1;  i1++) {
+                        for     (int k1 = 0; k1 < ne01; k1++) {
+                            for (int i0 = 0; i0 < nr0;  i0++) {
+                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
+                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
+                                // ggml_vec_cpy_f16(ne00, y, x)
+                                for (int i = 0; i < ne00; ++i) {
+                                    y[i]  = x[i];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_I16:
+            {
+                ggml_compute_forward_repeat_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_repeat_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_repeat_back
+
+static void ggml_compute_forward_repeat_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // guaranteed to be an integer due to the check in ggml_can_repeat
+    const int nr0 = (int)(ne00/ne0);
+    const int nr1 = (int)(ne01/ne1);
+    const int nr2 = (int)(ne02/ne2);
+    const int nr3 = (int)(ne03/ne3);
+
+    // TODO: support for transposed / permuted tensors
+    GGML_ASSERT(nb0  == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
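+    // reverse of repeat: zero dst, then accumulate every repeated copy of src0 back into it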
+    if (ggml_is_contiguous(dst)) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    } else {
+        for         (int k3 = 0; k3 < ne3; k3++) {
+            for     (int k2 = 0; k2 < ne2; k2++) {
+                for (int k1 = 0; k1 < ne1; k1++) {
+                    ggml_vec_set_f32(ne0,
+                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
+                        0);
+                }
+            }
+        }
+    }
+
+    // TODO: maybe this is not optimal?
+    for                         (int i3 = 0; i3 < nr3; i3++) {
+        for                     (int k3 = 0; k3 < ne3; k3++) {
+            for                 (int i2 = 0; i2 < nr2; i2++) {
+                for             (int k2 = 0; k2 < ne2; k2++) {
+                    for         (int i1 = 0; i1 < nr1; i1++) {
+                        for     (int k1 = 0; k1 < ne1; k1++) {
+                            for (int i0 = 0; i0 < nr0; i0++) {
+                                ggml_vec_acc_f32(ne0,
+                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
+                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_repeat_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_repeat_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_concat
+
+static void ggml_compute_forward_concat_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int32_t dim = ggml_get_op_params_i32(dst, 0);
+
+    GGML_ASSERT(dim >= 0 && dim < 4);
+
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = src0->ne[dim];
+
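+    // elements within src0's extents are copied from src0; the rest come from src1, shifted back along `dim` by src0's size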
+    const float * x;
+
+    // TODO: smarter multi-threading
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = ith; i2 < ne2; i2 += nth) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                for (int i0 = 0; i0 < ne0; i0++) {
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
+                    } else {
+                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
+                    }
+
+                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_concat(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_concat_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_abs
+
+static void ggml_compute_forward_abs_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_abs_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_abs(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_abs_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sgn
+
+static void ggml_compute_forward_sgn_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sgn_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sgn(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sgn_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_neg
+
+static void ggml_compute_forward_neg_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_neg_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_neg(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_neg_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_step
+
+static void ggml_compute_forward_step_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_step_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_step(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_step_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_tanh
+
+static void ggml_compute_forward_tanh_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_tanh_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_tanh(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_tanh_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_elu
+
+static void ggml_compute_forward_elu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_elu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_elu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_elu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_relu
+
+static void ggml_compute_forward_relu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_relu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_relu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_gelu
+
+static void ggml_compute_forward_gelu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu
+
+static void ggml_compute_forward_silu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_leaky_relu
+
+static void ggml_compute_forward_leaky_relu_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_leaky_relu_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
+    }
+}
+
+static void ggml_compute_forward_leaky_relu(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_leaky_relu_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_silu_back
+
+static void ggml_compute_forward_silu_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * grad = dst->src[1];
+
+    assert(ggml_is_contiguous_1(grad));
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+    assert(ggml_are_same_shape(src0, grad));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_silu_backward_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) grad->data + i1*(grad->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_silu_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_silu_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+static void ggml_compute_forward_hardswish_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardswish_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_hardswish(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_hardswish_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_hardsigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_hardsigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_hardsigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_hardsigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_exp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_exp_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_exp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_exp_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_norm
+
+static void ggml_compute_forward_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)x[i00];
+                }
+
+                float mean = sum/ne00;
+
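+                // center the row by its mean, then scale by 1/sqrt(variance + eps)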
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_float sum2 = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    float v = x[i00] - mean;
+                    y[i00] = v;
+                    sum2 += (ggml_float)(v*v);
+                }
+
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rms_norm
+
+static void ggml_compute_forward_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                const float mean = sum/ne00;
+
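+                // rms norm scales by 1/sqrt(mean(x^2) + eps) without subtracting the mean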
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                memcpy(y, x, ne00 * sizeof(float));
+                // for (int i00 = 0; i00 < ne00; i00++) {
+                //     y[i00] = x[i00];
+                // }
+
+                const float scale = 1.0f/sqrtf(mean + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_rms_norm_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                // src1 is same shape as src0 => same indices
+                const int64_t i11 = i01;
+                const int64_t i12 = i02;
+                const int64_t i13 = i03;
+
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+
+                ggml_float sum_xx  = 0.0;
+                ggml_float sum_xdz = 0.0;
+
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
+                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
+                }
+
+                //const float mean     = (float)(sum_xx)/ne00;
+                const float mean_eps = (float)(sum_xx)/ne00 + eps;
+                const float sum_eps  = (float)(sum_xx) + eps*ne00;
+                //const float mean_xdz = (float)(sum_xdz)/ne00;
+                // we could cache rms from forward pass to improve performance.
+                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
+                //const float rms      = sqrtf(mean_eps);
+                const float rrms     = 1.0f / sqrtf(mean_eps);
+                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
+
+                {
+                    // z = rms_norm(x)
+                    //
+                    // rms_norm(src0) =
+                    //     scale(
+                    //         src0,
+                    //         div(
+                    //             1,
+                    //             sqrt(
+                    //                 add(
+                    //                     scale(
+                    //                         sum(
+                    //                             sqr(
+                    //                                 src0)),
+                    //                         (1.0/N)),
+                    //                     eps))));
+
+                    // postorder:
+                    // ## op    args         grad
+                    // 00 param src0         grad[#00]
+                    // 01 const 1
+                    // 02 sqr   (#00)        grad[#02]
+                    // 03 sum   (#02)        grad[#03]
+                    // 04 const 1/N
+                    // 05 scale (#03, #04)   grad[#05]
+                    // 06 const eps
+                    // 07 add   (#05, #06)   grad[#07]
+                    // 08 sqrt  (#07)        grad[#08]
+                    // 09 div   (#01,#08)    grad[#09]
+                    // 10 scale (#00,#09)    grad[#10]
+                    //
+                    // backward pass, given grad[#10]
+                    // #10: scale
+                    // grad[#00] += scale(grad[#10],#09)
+                    // grad[#09] += sum(mul(grad[#10],#00))
+                    // #09: div
+                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
+                    // #08: sqrt
+                    // grad[#07] += mul(grad[#08], div(0.5, #08))
+                    // #07: add
+                    // grad[#05] += grad[#07]
+                    // #05: scale
+                    // grad[#03] += scale(grad[#05],#04)
+                    // #03: sum
+                    // grad[#02] += repeat(grad[#03], #02)
+                    // #02:
+                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
+                    //
+                    // substitute and simplify:
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#02] = repeat(grad[#03], #02)
+                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
+                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
+                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
+                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
+                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
+                    // a = b*c + d*e
+                    // a = b*c*f/f + d*e*f/f
+                    // a = (b*c*f + d*e*f)*(1/f)
+                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
+                    // a = (b + d*e/c)*c
+                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
+                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
+                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
+                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
+                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
+                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
+                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
+                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                }
+                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
+                // post-order:
+                // dx := x
+                // dx := scale(dx,-mean_xdz/mean_eps)
+                // dx := add(dx, dz)
+                // dx := scale(dx, rrms)
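+                // equivalently: dx_i = rrms * (dz_i - x_i * sum_xdz / sum_eps)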
+                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_vec_cpy_f32  (ne00, dx, x);
+                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
+                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
+                ggml_vec_acc_f32  (ne00, dx, dz);
+                ggml_vec_scale_f32(ne00, dx, rrms);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rms_norm_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rms_norm_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_group_norm
+
+static void ggml_compute_forward_group_norm_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    // TODO: optimize
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+    int n_channels = src0->ne[2];
+    int n_groups = dst->op_params[0];
+    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
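+    // normalize each group of channels independently: subtract the group mean, then scale by 1/sqrt(variance + eps)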
+    for (int i = ith; i < n_groups; i += nth) {
+        int start = i * n_channels_per_group;
+        int end = start + n_channels_per_group;
+        if (end > n_channels) {
+            end = n_channels;
+        }
+        int step = end - start;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            ggml_float sum = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    ggml_float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        sumr += (ggml_float)x[i00];
+                    }
+                    sum += sumr;
+                }
+            }
+            const float mean = sum / (ne00 * ne01 * step);
+
+            ggml_float sum2 = 0.0;
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
+
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+
+                    ggml_float sumr = 0.0;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        float v = x[i00] - mean;
+                        y[i00] = v;
+                        sumr += (ggml_float)(v * v);
+                    }
+                    sum2 += sumr;
+                }
+            }
+            const float variance = sum2 / (ne00 * ne01 * step);
+            const float scale = 1.0f / sqrtf(variance + eps);
+
+            for (int64_t i02 = start; i02 < end; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
+                    ggml_vec_scale_f32(ne00, y, scale);
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_group_norm(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_group_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_mul_mat
+
+static void ggml_compute_forward_mul_mat_one_chunk(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst,
+    const enum ggml_type type,
+    const int64_t num_rows_per_vec_dot,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t const vec_dot      = type_traits_cpu[type].vec_dot;
+    enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+    //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
+
+    // threads with no work simply yield (not sure if it helps)
+    if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
+        return;
+    }
+
+    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // block-tiling attempt
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+    // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
+                const int64_t i13 = (ir1 / (ne12 * ne1));
+                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const int64_t i03 = i13 / r3;
+                const int64_t i02 = i12 / r2;
+
+                const int64_t i1 = i11;
+                const int64_t i2 = i12;
+                const int64_t i3 = i13;
+
+                const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char*)wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+                float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                //}
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                }
+
+                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
+                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_mul_mat(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    enum ggml_type           const vec_dot_type         = type_traits_cpu[src0->type].vec_dot_type;
+    ggml_from_float_t        const from_float           = type_traits_cpu[vec_dot_type].from_float;
+    int64_t                  const vec_dot_num_rows     = type_traits_cpu[src0->type].nrows;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    // TODO: extract to "extra_op"
+#if GGML_USE_LLAMAFILE
+    // broadcast factors
+    const int64_t r2 = ne12 / ne02;
+    const int64_t r3 = ne13 / ne03;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    if (src1_cont) {
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(src0->type),
+                                     (const char *)src1->data + i12*nb12 + i13*nb13,
+                                     nb11/ggml_type_size(src1->type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     src0->type,
+                                     src1->type,
+                                     dst->type))
+                    goto UseGgmlGemm1;
+        return;
+    }
+UseGgmlGemm1:;
+#endif
+
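+    // if src1 is not already in the format expected by the dot-product kernel, convert its rows
+    // into the work buffer, with the rows split across threads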
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                                ne10);
+                }
+            }
+        }
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth.  This saves a bit of coordination right at the start.
+        atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
+    }
+
+    ggml_barrier(params->threadpool);
+
+#if GGML_USE_LLAMAFILE
+    if (src1->type != vec_dot_type) {
+        const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+        for (int64_t i13 = 0; i13 < ne13; i13++)
+            for (int64_t i12 = 0; i12 < ne12; i12++)
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
+                                     (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
+                                     nb01/ggml_type_size(src0->type),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
+                                     row_size/ggml_type_size(vec_dot_type),
+                                     (char *)dst->data + i12*nb2 + i13*nb3,
+                                     nb1/ggml_type_size(dst->type),
+                                     src0->type,
+                                     vec_dot_type,
+                                     dst->type))
+                    goto UseGgmlGemm2;
+        return;
+    }
+UseGgmlGemm2:;
+#endif
+
+    // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
+    const int64_t nr0 = ne0;
+
+    // This is the size of the rest of the dimensions of the result
+    const int64_t nr1 = ne1 * ne2 * ne3;
+
+    // Now select a reasonable chunk size.
+    int chunk_size = 16;
+
+    // We need to step up the size if it's small
+    if (nr0 == 1 || nr1 == 1) {
+        chunk_size = 64;
+    }
+
+    // distribute the work across the inner or outer loop based on which one is larger
+    // The number of chunks in the 0/1 dim.
+    // CEIL(nr0/chunk_size)
+    int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+    int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
+
+    // If the chunking is poor for the number of threads on this setup, scrap the whole plan.  Re-chunk it by thread.
+    //   Also, chunking by thread was measured to perform better on NUMA systems.  See https://github.com/ggerganov/llama.cpp/pull/6915
+    //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
+    if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
+        // distribute the thread work across the inner or outer loop based on which one is larger
+        nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+    }
+
+    // The number of elements in each chunk
+    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+    const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
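+
+    // worked example (illustrative): with nr0 = 4096, nr1 = 512, chunk_size = 16 and nth = 8,
+    // nchunk0 = 256 and nchunk1 = 32, i.e. 8192 chunks >= nth*4, so the 2-D chunking is kept
+    // (assuming non-NUMA) and each chunk spans dr0 = 16 elements of dim0 by dr1 = 16 of dim1.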
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk0 * nchunk1) {
+        const int64_t ith0 = current_chunk % nchunk0;
+        const int64_t ith1 = current_chunk / nchunk0;
+
+        const int64_t ir0_start = dr0 * ith0;
+        const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
+
+        const int64_t ir1_start = dr1 * ith1;
+        const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
+
+        // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
+        int64_t num_rows_per_vec_dot = vec_dot_num_rows;
+
+        // these checks are needed to avoid crossing dim1 boundaries
+        // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
+        if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
+            num_rows_per_vec_dot = 1;
+        }
+
+        ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
+
+        if (nth >= nchunk0 * nchunk1) {
+            break;
+        }
+
+        current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
+    }
+}
+
+// ggml_compute_forward_mul_mat_id
+
+static void ggml_compute_forward_mul_mat_id(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * ids = dst->src[2];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+
+    const bool src1_cont = ggml_is_contiguous(src1);
+
+    ggml_vec_dot_t    const vec_dot         = type_traits_cpu[type].vec_dot;
+    enum ggml_type    const vec_dot_type    = type_traits_cpu[type].vec_dot_type;
+    ggml_from_float_t const from_float      = type_traits_cpu[vec_dot_type].from_float;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // row groups
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_expert
+
+    char * wdata_src1_end = (src1->type == vec_dot_type) ?
+            (char *) params->wdata :
+            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+
+    struct mmid_row_mapping {
+        int32_t i1;
+        int32_t i2;
+    };
+
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
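+
+    // workspace layout as set up above: [src1 converted to vec_dot_type, if needed]
+    // [pad to int64_t] [matrix_row_counts: n_as * int64_t] [matrix_rows mappings];
+    // src1 rows are bucketed per expert before the per-expert matmuls below run.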
+
+    if (src1->type != vec_dot_type) {
+        char * wdata = params->wdata;
+
+        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        assert(params->wsize >= ne13*nbw3);
+        GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
+                               ne10);
+                }
+            }
+        }
+    }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
+    if (ith == 0) {
+        // initialize matrix_row_counts
+        memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+
+        // group rows by src0 matrix
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+            for (int id = 0; id < n_ids; ++id) {
+                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
+
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+                matrix_row_counts[i02] += 1;
+            }
+        }
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // compute each matrix multiplication in sequence
+    for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+        const int64_t cne1 = matrix_row_counts[cur_a];
+
+        if (cne1 == 0) {
+            continue;
+        }
+
+        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
+
+        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const size_t row_size = ggml_row_size(vec_dot_type, ne10);
+
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1; // src1 rows
+
+        // distribute the thread work across the inner or outer loop based on which one is larger
+
+        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+        const int64_t ith0 = ith % nth0;
+        const int64_t ith1 = ith / nth0;
+
+        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+
+        const int64_t ir010 = dr0*ith0;
+        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+        const int64_t ir110 = dr1*ith1;
+        const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+        // threads with no work simply yield (not sure if it helps)
+        //if (ir010 >= ir011 || ir110 >= ir111) {
+        //    sched_yield();
+        //    continue;
+        //}
+
+        // block-tiling attempt
+        const int64_t blck_0 = 16;
+        const int64_t blck_1 = 16;
+
+        // attempt to reduce false-sharing (does not seem to make a difference)
+        float tmp[16];
+
+        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
+            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
+                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
+                    const int64_t _i12 = ir1; // logical row index for this expert
+
+                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                    const int id       = row_mapping.i1; // selected expert index
+
+                    const int64_t  i11 = id % ne11;
+                    const int64_t  i12 = row_mapping.i2; // row index in src1
+
+                    const int64_t  i1 = id;  // selected expert index
+                    const int64_t  i2 = i12; // row
+
+                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                    //       the original src1 data pointer, so we should index using the indices directly
+                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                    const char * src1_col = (const char *) wdata +
+                        (src1_cont || src1->type != vec_dot_type
+                        ? (i11      + i12*ne11)*row_size
+                        : (i11*nb11 + i12*nb12));
+
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                    //}
+
+                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
+                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                    }
+
+                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
+                }
+            }
+        }
+    }
+
+#undef MMID_MATRIX_ROW
+}
+
+// ggml_compute_forward_out_prod
+
+static void ggml_compute_forward_out_prod_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // block-tiling attempt
+    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
+    const int64_t blck_1 = 16;
+
+    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
+        const int64_t bir1 = MIN(bir + blck_1, ir1);
+        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
+            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
+            for (int64_t ir = bir; ir < bir1; ++ir) {
+                // dst indices
+                const int64_t i3 = ir/(ne2*ne1);
+                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                const int64_t i02 = i2;
+                const int64_t i03 = i3;
+
+                //const int64_t i10 = i1;
+                const int64_t i12 = i2;
+                const int64_t i13 = i3;
+
+#if GGML_VEC_MAD_UNROLL > 2
+                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
+                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
+                }
+                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#else
+                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
+                    const int64_t i11 = i01;
+
+                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+                    ggml_vec_mad_f32(ne0, d, s0, *s1);
+                }
+#endif
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_out_prod_q_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    // we don't support permuted src0 dim0
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+
+    // dst dim0 cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    // nb01 >= nb00 - src0 is not transposed
+    //   compute by src0 rows
+
+    if (ith == 0) {
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+    }
+    ggml_barrier(params->threadpool);
+
+    // parallelize by last three dimensions
+
+    // total rows in dst
+    const int64_t nr = ne1*ne2*ne3;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    // dst[:,:,:,:] = 0
+    // for i2,i3:
+    //   for i1:
+    //     for i01:
+    //       for i0:
+    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
+
+    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        // dst indices
+        const int64_t i3 = ir/(ne2*ne1);
+        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
+        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        const int64_t i02 = i2;
+        const int64_t i03 = i3;
+
+        //const int64_t i10 = i1;
+        const int64_t i12 = i2;
+        const int64_t i13 = i3;
+
+        for (int64_t i01 = 0; i01 < ne01; ++i01) {
+            const int64_t i11 = i01;
+
+            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
+            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+
+            dequantize_row_q(s0, wdata, ne0);
+            ggml_vec_mad_f32(ne0, d, wdata, *s1);
+        }
+    }
+}
+
+static void ggml_compute_forward_out_prod(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_out_prod_q_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                GGML_ABORT("fatal error"); // todo
+                // ggml_compute_forward_out_prod_f16_f32(params, dst);
+            }
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_out_prod_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_scale
+
+static void ggml_compute_forward_scale_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // scale factor
+    float v;
+    memcpy(&v, dst->op_params, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        if (dst->data != src0->data) {
+            // src0 is same shape as dst => same indices
+            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+        }
+        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
+    }
+}
+
+static void ggml_compute_forward_scale(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_scale_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_set
+
+static void ggml_compute_forward_set_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset inbytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
+static void ggml_compute_forward_set_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset inbytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_i32(nc,
+                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
+static void ggml_compute_forward_set(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_f32(params, dst);
+            } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_set_i32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cpy
+
+static void ggml_compute_forward_cpy(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    ggml_compute_forward_dup(params, dst);
+}
+
+// ggml_compute_forward_reshape
+
+static void ggml_compute_forward_reshape(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_view
+
+static void ggml_compute_forward_view(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_permute
+
+static void ggml_compute_forward_permute(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_transpose
+
+static void ggml_compute_forward_transpose(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * dst) {
+    // NOP
+    UNUSED(params);
+    UNUSED(dst);
+}
+
+// ggml_compute_forward_get_rows
+
+static void ggml_compute_forward_get_rows_q(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_fp16_to_fp32_row(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_bf16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_bf16_t));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_bf16_to_fp32_row(
+                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
+    }
+}
+
+static void ggml_compute_forward_get_rows_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
+    assert(ne0  == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
+    assert(ggml_nrows(dst) == nr);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i = ir0; i < ir1; ++i) {
+        const int64_t i12 = i/(ne11*ne10);
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
+        ggml_vec_cpy_f32(nc,
+                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
+                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+    }
+}
+
+static void ggml_compute_forward_get_rows(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+            {
+                ggml_compute_forward_get_rows_q(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_f16(params, dst);
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_get_rows_bf16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_get_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_get_rows_back
+
+static void ggml_compute_forward_get_rows_back_f32_f16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    memset(dst->data, 0, ggml_nbytes(dst));
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        for (int j = 0; j < nc; ++j) {
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
+            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v);
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rows_back_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
+
+    memset(dst->data, 0, ggml_nbytes(dst));
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nelements(src1);
+
+    GGML_ASSERT( dst->ne[0] == nc);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < nr; ++i) {
+        const int r = ((int32_t *) src1->data)[i];
+
+        ggml_vec_add_f32(nc,
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *)  dst->data + r*dst->nb[1]),
+                (float *) ((char *) src0->data + i*src0->nb[1]));
+    }
+}
+
+static void ggml_compute_forward_get_rows_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_diag
+
+static void ggml_compute_forward_diag_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(ne00 == ne0);
+    GGML_ASSERT(ne00 == ne1);
+    GGML_ASSERT(ne01 == 1);
+    GGML_ASSERT(ne02 == ne2);
+    GGML_ASSERT(ne03 == ne3);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    for (int i3 = 0; i3 < ne3; i3++) {
+        for (int i2 = 0; i2 < ne2; i2++) {
+            for (int i1 = 0; i1 < ne1; i1++) {
+                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
+                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
+                for (int i0 = 0; i0 < i1; i0++) {
+                    d[i0] = 0;
+                }
+                d[i1] = s[i1];
+                for (int i0 = i1+1; i0 < ne0; i0++) {
+                    d[i0] = 0;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_diag(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_diag_mask_inf
+
+static void ggml_compute_forward_diag_mask_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const float value) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int  n_past  = ((int32_t *) dst->op_params)[0];
+    const bool inplace = src0->data == dst->data;
+
+    GGML_ASSERT(n_past >= 0);
+
+    if (!inplace) {
+        if (ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    // TODO: handle transposed/permuted matrices
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+    const int nr = src0->ne[1];
+    const int nz = n/nr;
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
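+    // note: for row j, columns 0 .. n_past + j are left untouched and all later columns are
+    // set to `value` (-INFINITY for diag_mask_inf, 0 for diag_mask_zero), i.e. a causal mask
+    // with an unmasked prefix of n_past columns.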
+    for (int k = 0; k < nz; k++) {
+        for (int j = ith; j < nr; j += nth) {
+            for (int i = n_past; i < nc; i++) {
+                if (i > n_past + j) {
+                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_diag_mask_inf(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_diag_mask_zero(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_diag_mask_f32(params, dst, 0);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_soft_max
+
+static void ggml_compute_forward_soft_max_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
+    // TODO: is this supposed to be ceil instead of floor?
+    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
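+
+    // illustrative numbers: with max_bias = 8.0f and n_head = 32 (so n_head_log2 = 32),
+    // m0 = 2^(-8/32) ~= 0.8409 and m1 = 2^(-4/32) ~= 0.9170; the per-head ALiBi slope below
+    // is then m0^(h+1) for h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) otherwise.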
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        // ALiBi
+        const uint32_t h = (i1/ne01)%ne02; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+        float       * mp_f32 = src1 ? (float       *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp_f32) {
+            if (use_f16) {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
+                }
+            } else {
+                for (int i = 0; i < nc; ++i) {
+                    wp[i] += slope*mp_f32[i];
+                }
+            }
+        }
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(wp[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, wp);
+
+        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
+        assert(sum > 0.0);
+
+        sum = 1.0/sum;
+        ggml_vec_scale_f32(nc, dp, sum);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dp[i]));
+            assert(!isinf(dp[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_soft_max_back
+
+static void ggml_compute_forward_soft_max_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+
+    // TODO: handle transposed/permuted matrices
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float *y  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float *dx = (float *)((char *) dst->data  + i1*dst->nb[1]);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(dy[i]));
+            assert(!isnan(y[i]));
+        }
+#endif
+        // Jii = yi - yi*yi
+        // Jij = -yi*yj
+        // J = diag(y)-y.T*y
+        // dx = J * dy
+        // dxk = sum_i(Jki * dyi)
+        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
+        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
+        // dxk = -yk * dot(y, dy) + yk*dyk
+        // dxk = yk * (- dot(y, dy) + dyk)
+        // dxk = yk * (dyk - dot(y, dy))
+        //
+        // post-order:
+        // dot_y_dy := dot(y, dy)
+        // dx := dy
+        // dx := dx - dot_y_dy
+        // dx := dx * y
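+        //
+        // sanity check: since sum_k(yk) = 1 for a softmax output, sum_k(dxk) equals
+        // dot(y, dy) - dot(y, dy) = 0, i.e. the gradient of each row sums to zero.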
+
+        // linear runtime, no additional memory
+        float dot_y_dy = 0;
+        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32 (nc, dx, dy);
+        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32 (nc, dx, dx, y);
+
+#ifndef NDEBUG
+        for (int i = 0; i < nc; ++i) {
+            assert(!isnan(dx[i]));
+            assert(!isinf(dx[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_soft_max_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_soft_max_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_clamp
+
+static void ggml_compute_forward_clamp_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    float min;
+    float max;
+    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    for (int j = ith; j < n; j += nth) {
+        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
+        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+        }
+    }
+}
+
+static void ggml_compute_forward_clamp(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_clamp_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_Q8_K:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+        case GGML_TYPE_F64:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rope
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+    return 1 - MIN(1, MAX(0, y));
+}
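+// note on the ramp above: it evaluates to 1 for dims with i0/2 <= low, to 0 for i0/2 >= high,
+// and interpolates linearly in between; rope_yarn uses it to blend interpolated and
+// extrapolated angles per rotary dimension.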
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+static void ggml_rope_cache_init(
+     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta = theta_base;
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+        rope_yarn(
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta *= theta_scale;
+    }
+}
+
+static void ggml_mrope_cache_init(
+     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects,
+     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
+     float * cache, float sin_sign, float theta_scale) {
+    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
+    float theta_t = theta_base_t;
+    float theta_h = theta_base_h;
+    float theta_w = theta_base_w;
+    float theta_e = theta_base_e;  // extra position id for vision encoder
+    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+    int sec_w = sections[1] + sections[0];
+    int sec_e = sections[2] + sec_w;
+    GGML_ASSERT(sect_dims <= ne0);
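+
+    // illustrative example (not tied to a specific model): sections = {16, 24, 24, 0} assigns
+    // rotary dim pairs 0..15 to the temporal position, 16..39 to height and 40..63 to width;
+    // the pattern repeats every sect_dims pairs when ne0/2 > sect_dims.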
+
+    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
+
+        int sector = (i0 / 2) % sect_dims;
+        if (indep_sects) {
+            // compute theta independently for each dim section
+            // (i.e. reset the corresponding theta when `i0` goes from one section to another)
+            if (sector == 0) {
+                theta_t = theta_base_t;
+            }
+            else if (sector == sections[0]) {
+                theta_h = theta_base_h;
+            }
+            else if (sector == sec_w) {
+                theta_w = theta_base_w;
+            }
+            else if (sector == sec_e) {
+                theta_e = theta_base_e;
+            }
+        }
+
+        float theta = theta_t;
+        if (sector >= sections[0] && sector < sec_w) {
+            theta = theta_h;
+        }
+        else if (sector >= sec_w && sector < sec_w + sections[2]) {
+            theta = theta_w;
+        }
+        else if (sector >= sec_w + sections[2]) {
+            theta = theta_e;
+        }
+
+        rope_yarn(
+            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
+        );
+        cache[i0 + 1] *= sin_sign;
+
+        theta_t *= theta_scale;
+        theta_w *= theta_scale;
+        theta_h *= theta_scale;
+        theta_e *= theta_scale;
+    }
+}
+
+static void ggml_compute_forward_rope_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const bool forward) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
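+    // theta_scale = freq_base^(-2/n_dims), so rotation pair i uses the angle
+    // pos * freq_base^(-2*i/n_dims) -- the standard RoPE frequency schedule
+    // (e.g. freq_base = 10000, n_dims = 128 gives theta_scale = 10000^(-1/64))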
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
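+
+    // rotation pairing used below: normal mode rotates adjacent values (x[i0], x[i0+1]);
+    // neox/mrope pair x[ic] with x[ic + n_dims/2]; vision mode pairs x[ic] with x[ic + n_dims]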
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                if (is_neox || is_mrope) {
+                    if (is_vision){
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims];
+
+                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                        }
+                    } else {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = src[0];
+                            const float x1 = src[n_dims/2];
+
+                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                        }
+                    }
+                } else {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
+
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    }
+                }
+
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims];
+
+                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
+                    }
+                } else {
+                    // fill the remaining channels with data from the src tensor
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
+                }
+            }
+        }
+    }
+}
+
+// TODO: deduplicate f16/f32 code
+static void ggml_compute_forward_rope_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const bool forward) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
+    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
+
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const int32_t * pos = (const int32_t *) src1->data;
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            if (!is_mrope) {
+                const int64_t p = pos[i2];
+                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+            else {
+                const int64_t p_t = pos[i2];
+                const int64_t p_h = pos[i2 + ne2];
+                const int64_t p_w = pos[i2 + ne2 * 2];
+                const int64_t p_e = pos[i2 + ne2 * 3];
+                ggml_mrope_cache_init(
+                    p_t, p_h, p_w, p_e, sections, is_vision,
+                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
+            }
+
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                if (is_neox || is_mrope) {
+                    if (is_vision) {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                            dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
+                    } else {
+                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                            const int64_t ic = i0/2;
+
+                            const float cos_theta = cache[i0 + 0];
+                            const float sin_theta = cache[i0 + 1];
+
+                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                            const float x0 = GGML_FP16_TO_FP32(src[0]);
+                            const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                            dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                            dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        }
+                    }
+                } else {
+                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);
+
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                }
+
+                if (is_vision) {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const int64_t ic = i0/2;
+
+                        const float cos_theta = cache[i0 + 0];
+                        const float sin_theta = cache[i0 + 1];
+
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims]);
+
+                        dst_data[0]      = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
+                } else {
+                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+                        dst_data[0] = src[0];
+                        dst_data[1] = src[1];
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_rope(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, dst, true);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, dst, true);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rope_back
+
+static void ggml_compute_forward_rope_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_f16(params, dst, false);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_f32(params, dst, false);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_conv_transpose_1d
+
+static void ggml_compute_forward_conv_transpose_1d_f16_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (L x Cin) to (Cin x L)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            ggml_fp16_t * dst_data = wdata;
+
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]);
+                }
+            }
+        }
+
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
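+
+    // after the permutes above, Cin is the innermost dimension of both work buffers, so every
+    // output element is one dot product over Cin: for output channel i1, input position i10 and
+    // kernel tap i00, the result is accumulated into dst[i1][i10*s0 + i00]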
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    // total rows in dst
+    const int nr = ne1;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f16(ne02, &v, 0,
+                        (ggml_fp16_t *)    wdata_src + i1n, 0,
+                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
+                dst_data[i10*s0 + i00] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_transpose_1d_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02;
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
+        {
+            float * const wdata = (float *) params->wdata + 0;
+
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
+                    float * dst_data = wdata + i01*ne00*ne02;
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        dst_data[i00*ne02 + i02] = src[i00];
+                    }
+                }
+            }
+        }
+
+        // prepare source data (src1)
+        {
+            float * const wdata = (float *) params->wdata + nk;
+            float * dst_data = wdata;
+
+            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                const float * const src = (float *)((char *) src1->data + i11*nb11);
+                for (int64_t i10 = 0; i10 < ne10; i10++) {
+                    dst_data[i10*ne11 + i11] = src[i10];
+                }
+            }
+        }
+
+        // need to zero dst since we are accumulating into it
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+
+    // total rows in dst
+    const int nr = ne1;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float * const wdata     = (float *) params->wdata + 0;
+    float * const wdata_src = wdata + nk;
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * dst_data = (float *)((char *) dst->data + i1*nb1);
+        float * wdata_kernel = wdata + i1*ne02*ne00;
+        for (int i10 = 0; i10 < ne10; i10++) {
+            const int i1n = i10*ne11;
+            for (int i00 = 0; i00 < ne00; i00++) {
+                float v = 0;
+                ggml_vec_dot_f32(ne02, &v, 0,
+                        wdata_src + i1n, 0,
+                        wdata_kernel + i00*ne02, 0, 1);
+                dst_data[i10*s0 + i00] += v;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_conv_transpose_1d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_im2col_f32
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
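+// e.g. (2D case) a 3x3 kernel over a single 5x5 image with IC = 3, stride 1, no padding and
+// dilation 1 yields dst [1, 3, 3, 27]: each of the 3*3 output positions stores the 27 input
+// values covered by the kernel window (out-of-bounds positions are written as 0)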
+static void ggml_compute_forward_im2col_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+
+// ggml_compute_forward_im2col_f16
+// src0: kernel [OC, IC, KH, KW]
+// src1: image [N, IC, IH, IW]
+// dst:  result [N, OH, OW, IC*KH*KW]
+static void ggml_compute_forward_im2col_f16(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne13 : ne12;
+    const int64_t IC = is_2D ? ne12 : ne11;
+    const int64_t IH = is_2D ? ne11 : 1;
+    const int64_t IW = ne10;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    int ofs0 = is_2D ? nb13 : nb12;
+    int ofs1 = is_2D ? nb12 : nb11;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
+                for (int64_t iow = 0; iow < OW; iow++) {
+                    for (int64_t iic = ith; iic < IC; iic += nth) {
+
+                        // micro kernel
+                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
+                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
+                                } else {
+                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_im2col(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_im2col_f16(params, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_im2col_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_im2col_back_f32
+
+static void ggml_compute_forward_im2col_back_f32(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t N  = is_2D ? ne3 : ne2;
+    const int64_t IC = is_2D ? ne2 : ne1;
+    const int64_t IH = is_2D ? ne1 : 1;
+    const int64_t IW = ne0;
+
+    const int64_t KH = is_2D ? ne01 : 1;
+    const int64_t KW = ne00;
+
+    const int64_t OH = is_2D ? ne12 : 1;
+    const int64_t OW = ne11;
+
+    int ofs0 = is_2D ? nb3 : nb2;
+    int ofs1 = is_2D ? nb2 : nb1;
+
+    GGML_ASSERT(nb0  == sizeof(float));
+
+    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+    {
+        float * const wdata = (float *) dst->data;
+
+        for (int64_t in = 0; in < N; in++) {
+            for (int64_t iic = ith; iic < IC; iic += nth) {
+                for (int64_t iih = 0; iih < IH; iih++) {
+                    for (int64_t iiw = 0; iiw < IW; iiw++) {
+
+                        // micro kernel
+                        float grad = 0.0f;
+                        for (int64_t ikh = 0; ikh < KH; ikh++) {
+                            for (int64_t ikw = 0; ikw < KW; ikw++) {
+                                // For s0 > 1 some values were skipped over in the forward pass.
+                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
+                                const int64_t tmpw = (iiw + p0 - ikw*d0);
+                                if (tmpw % s0 != 0) {
+                                    continue;
+                                }
+                                const int64_t iow = tmpw / s0;
+
+                                // Same logic as above, but applied to s1.
+                                int64_t ioh;
+                                if (is_2D) {
+                                    const int64_t tmph = iih + p1 - ikh*d1;
+
+                                    if (tmph % s1 != 0) {
+                                        continue;
+                                    }
+
+                                    ioh = tmph / s1;
+                                } else {
+                                    ioh = 0;
+                                }
+
+                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
+                                    continue;
+                                }
+
+                                const float * const src_data = (const float *) src1->data
+                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                            }
+                        }
+                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
+                        dst_data[iih*IW + iiw] = grad;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_conv_transpose_2d
+
+static void ggml_compute_forward_conv_transpose_2d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nk = ne00*ne01*ne02*ne03;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    if (ith == 0) {
+        memset(params->wdata, 0, params->wsize);
+
+        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+
+            for (int64_t i03 = 0; i03 < ne03; i03++) {
+                for (int64_t i02 = 0; i02 < ne02; i02++) {
+                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
+                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
+                    for (int64_t i01 = 0; i01 < ne01; i01++) {
+                        for (int64_t i00 = 0; i00 < ne00; i00++) {
+                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
+                        }
+                    }
+                }
+            }
+        }
+
+        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
+        {
+            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
+            for (int i12 = 0; i12 < ne12; i12++) {
+                for (int i11 = 0; i11 < ne11; i11++) {
+                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
+                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
+                    for (int i10 = 0; i10 < ne10; i10++) {
+                        dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]);
+                    }
+                }
+            }
+        }
+
+        memset(dst->data, 0, ggml_nbytes(dst));
+    }
+    ggml_barrier(params->threadpool);
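+
+    // Cin (ne03) is now the innermost dimension of both work buffers, so for each output
+    // channel i2, input position (i10, i11) and kernel tap (i00, i01) a single dot product
+    // over Cin is accumulated into dst at the strided offset (i11*stride + i01, i10*stride + i00)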
+
+    const int32_t stride = ggml_get_op_params_i32(dst, 0);
+
+    // total patches in dst
+    const int np = ne2;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+    ggml_fp16_t * const wdata_src = wdata + nk;
+
+    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
+        float * dst_data = (float *)((char *) dst->data + i2*nb2);
+        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
+        for (int i11 = 0; i11 < ne11; i11++) {
+            for (int i10 = 0; i10 < ne10; i10++) {
+                const int i1n = i11*ne10*ne12 + i10*ne12;
+                for (int i01 = 0; i01 < ne01; i01++) {
+                    for (int i00 = 0; i00 < ne00; i00++) {
+                        float v = 0;
+                        ggml_vec_dot_f16(ne03, &v, 0,
+                                wdata_src + i1n, 0,
+                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
+                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// ggml_compute_forward_pool_1d_sk_p0
+
+static void ggml_compute_forward_pool_1d_sk_p0(
+        const struct ggml_compute_params * params,
+        const enum ggml_op_pool op,
+        const int k,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src = dst->src[0];
+
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const char * cdata = (const char *)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+    float * drow = (float *)dst->data;
+
+    const int64_t rs = dst->ne[0];
+
+    while (cdata < data_end) {
+        const void * srow = (const void *)cdata;
+        int j = 0;
+        for (int64_t i = 0; i < rs; ++i) {
+            switch (op) {
+                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
+                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+            }
+            for (int ki = 0; ki < k; ++ki) {
+                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                switch (op) {
+                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
+                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
+                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
+                }
+                ++j;
+            }
+            switch (op) {
+                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
+                case GGML_OP_POOL_MAX:                       break;
+                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+            }
+        }
+
+        cdata += src->nb[1];
+        drow  += rs;
+    }
+}
+
+// ggml_compute_forward_pool_1d
+
+static void ggml_compute_forward_pool_1d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int s0 = opts[2];
+    const int p0 = opts[3];
+    GGML_ASSERT(p0 == 0); // padding not supported
+    GGML_ASSERT(k0 == s0); // only s = k supported
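+    // e.g. k0 == s0 == 2 reduces each non-overlapping pair of input values to one output
+    // value (their average or maximum), halving ne[0]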
+
+    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+}
+
+// ggml_compute_forward_pool_2d
+
+static void ggml_compute_forward_pool_2d(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src = dst->src[0];
+
+    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+    const char * cdata = (const char*)src->data;
+    const char * const data_end = cdata + ggml_nbytes(src);
+
+    const int64_t px = dst->ne[0];
+    const int64_t py = dst->ne[1];
+    const int64_t pa = px * py;
+
+    float * dplane = (float *)dst->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            float * const drow = dplane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                float * const out =  drow + ox;
+                switch (op) {
+                    case GGML_OP_POOL_AVG:     *out = 0;        break;
+                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
+                for (int ky = 0; ky < k1; ++ky) {
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
+                    for (int kx = 0; kx < k0; ++kx) {
+                        int j = ix + kx;
+                        if (j < 0 || j >= src->ne[0]) continue;
+                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                        switch (op) {
+                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
+                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
+                        }
+                    }
+                }
+                switch (op) {
+                    case GGML_OP_POOL_AVG:           *out /= ka; break;
+                    case GGML_OP_POOL_MAX:                       break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
+                }
+            }
+        }
+
+        cdata  += src->nb[2];
+        dplane += pa;
+    }
+}
+
+// ggml_compute_forward_pool_2d_back
+
+static void ggml_compute_forward_pool_2d_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src  = dst->src[0];
+    const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
+
+    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = opts[0];
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    char       * cdata  = (char       *) dst->data;
+    const char * cdataf = (const char *) dstf->data;
+    const char * const data_end = cdata + ggml_nbytes(dst);
+
+    GGML_ASSERT(params->ith == 0);
+    memset(cdata, 0, ggml_nbytes(dst));
+
+    const int64_t px = src->ne[0];
+    const int64_t py = src->ne[1];
+    const int64_t pa = px * py;
+
+    const float * splane = (const float *) src->data;
+
+    const int ka = k0 * k1;
+    const int offset0 = -p0;
+    const int offset1 = -p1;
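+
+    // gradient routing: for POOL_MAX the incoming gradient grad0 of each output element goes
+    // entirely to the input position that held the window maximum in the forward pass
+    // (recomputed from dstf); for POOL_AVG it is spread as grad0/ka over the whole window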
+
+    while (cdata < data_end) {
+        for (int oy = 0; oy < py; ++oy) {
+            const float * const srow = splane + oy * px;
+            for (int ox = 0; ox < px; ++ox) {
+                const float grad0 = srow[ox];
+
+                const int ix = offset0 + ox * s0;
+                const int iy = offset1 + oy * s1;
+
+                if (op == GGML_OP_POOL_MAX) {
+                    float maxval = -FLT_MAX;
+                    int kxmax = -1;
+                    int kymax = -1;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            const float val = dst->type == GGML_TYPE_F32 ?
+                                ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
+                            if (val <= maxval) {
+                                continue;
+                            }
+
+                            maxval = val;
+                            kxmax = kx;
+                            kymax = ky;
+                        }
+                    }
+
+                    if (kxmax == -1 || kymax == -1) {
+                        continue;
+                    }
+
+                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
+                    const int j = ix + kxmax;
+                    if (dst->type == GGML_TYPE_F32) {
+                        ((float *) drow)[j] += grad0;
+                    } else {
+                        ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
+                    }
+                } else if (op == GGML_OP_POOL_AVG) {
+                    const float grad = grad0 / ka;
+
+                    for (int ky = 0; ky < k1; ++ky) {
+                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
+                            continue;
+                        }
+                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
+                        for (int kx = 0; kx < k0; ++kx) {
+                            int j = ix + kx;
+                            if (j < 0 || j >= dst->ne[0]) {
+                                continue;
+                            }
+
+                            if (dst->type == GGML_TYPE_F32) {
+                                ((float *) drow)[j] += grad;
+                            } else {
+                                ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad);
+                            }
+                        }
+                    }
+                } else {
+                    GGML_ASSERT(false);
+                }
+            }
+        }
+
+        cdata  += dst->nb[2];
+        cdataf += dst->nb[2];
+        splane += pa;
+    }
+}
+
+// ggml_compute_forward_upscale
+
+static void ggml_compute_forward_upscale_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const float sf0 = (float)ne0/src0->ne[0];
+    const float sf1 = (float)ne1/src0->ne[1];
+    const float sf2 = (float)ne2/src0->ne[2];
+    const float sf3 = (float)ne3/src0->ne[3];
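+
+    // nearest-neighbour upscaling: each destination index maps back to a source index by
+    // dividing by the per-dimension scale factor and truncating (e.g. i01 = (int64_t)(i1 / sf1))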
+
+    // TODO: optimize
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        const int64_t i03 = i3 / sf3;
+        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+            const int64_t i02 = i2 / sf2;
+            for (int64_t i1 = 0; i1 < ne1; i1++) {
+                const int64_t i01 = i1 / sf1;
+                for (int64_t i0 = 0; i0 < ne0; i0++) {
+                    const int64_t i00 = i0 / sf0;
+
+                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                          float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
+
+                    *y = *x;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_upscale(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_upscale_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+    const struct ggml_compute_params * params,
+          struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    } else {
+                        dst_ptr[dst_idx] = 0;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_pad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_pad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_pad_reflect_1d
+
+static void ggml_compute_forward_pad_reflect_1d(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int32_t * opts = (const int32_t *) dst->op_params;
+    const int p0 = opts[0];
+    const int p1 = opts[1];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
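+
+    // for each row: copy the ne00 source values into dst starting at offset p0, then mirror
+    // p0 values to the left and p1 values to the right without repeating the edge element
+    // (reflect padding)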
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
+                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
+
+                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
+
+                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
+                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_unpad_f32(
+    const struct ggml_compute_params *params,
+    struct ggml_tensor *dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float * dst_ptr = (float *) dst->data;
+
+    // TODO: optimize
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                for (int64_t i3 = 0; i3 < ne3; ++i3) {
+                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+                    const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+                        dst_ptr[dst_idx] = *src_ptr;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_unpad(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_unpad_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_arange
+
+static void ggml_compute_forward_arange_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    GGML_ASSERT(dst->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const float start = ggml_get_op_params_f32(dst, 0);
+    const float stop  = ggml_get_op_params_f32(dst, 1);
+    const float step  = ggml_get_op_params_f32(dst, 2);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    for (int64_t i = ith; i < steps; i+= nth) {
+        float value = start + step * i;
+        ((float *)dst->data)[i] = value;
+    }
+}
+
+static void ggml_compute_forward_arange(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_arange_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int dim = ggml_get_op_params_i32(dst, 0);
+    const int max_period = ggml_get_op_params_i32(dst, 1);
+
+    int half = dim / 2;
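+
+    // standard sinusoidal timestep embedding: for timestep t = src0[i], frequency j is
+    // max_period^(-j/half); the first half of each output row stores cos(t*freq_j), the
+    // second half sin(t*freq_j), and an odd dim gets a trailing zero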
+
+    for (int64_t i = 0; i < ne00; i++) {
+        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
+        for (int64_t j = ith; j < half; j += nth) {
+            float timestep = ((float *)src0->data)[i];
+            float freq = (float)expf(-logf(max_period) * j / half);
+            float arg = timestep * freq;
+            embed_data[j] = cosf(arg);
+            embed_data[j + half] = sinf(arg);
+        }
+        if (dim % 2 != 0 && ith == 0) {
+            embed_data[dim] = 0.f;
+        }
+    }
+}
+
+static void ggml_compute_forward_timestep_embedding(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_timestep_embedding_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_argsort
+
+static void ggml_compute_forward_argsort_f32(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0);
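+    // each dst row receives the permutation of column indices that sorts the matching src0 row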
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const float * src_data = (float *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // C doesn't have a functional sort, so we do a bubble sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_argsort(
+    const struct ggml_compute_params * params,
+    struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_argsort_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    // parallelize by q rows using the type-specific kq vec_dot
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
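+    // ALiBi-style per-head slopes: heads with index below n_head_log2 use powf(m0, h + 1),
+    // the rest use powf(m1, 2*(h - n_head_log2) + 1); with max_bias == 0 every slope is 1
+    // and the mask is applied unscaled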
+
+    enum ggml_type    const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits_cpu[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits_cpu[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
+
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        const uint32_t h = iq2; // head index
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
+
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
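+        // streaming softmax accumulation per q row:
+        //   M   <- max(M, s)
+        //   VKQ <- VKQ * exp(M_old - M) + v * exp(s - M)
+        //   S   <- S   * exp(M_old - M) +     exp(s - M)
+        // so that at the end VKQ / S == softmax(scores) @ V without materializing the scores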
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s; // KQ value
+
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+
+            s = s*scale; // scale KQ value
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*tanhf(s);
+            }
+
+            s += mv; // apply mask
+
+            const float Mold = M;
+
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            if (v->type == GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+            } else {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
+
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
+
+                v_to_float(v_data, V32, D);
+
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
+        }
+
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
+        }
+
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[3]) {
+        case GGML_PREC_DEFAULT:
+        case GGML_PREC_F32:
+            {
+                // uses F32 accumulators
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const struct ggml_compute_params * params,
+        const bool masked,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * q = dst->src[0];
+    const struct ggml_tensor * k = dst->src[1];
+    const struct ggml_tensor * v = dst->src[2];
+    const struct ggml_tensor * d = dst->src[3];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (ith == 0) {
+        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+    }
+    ggml_barrier(params->threadpool);
+
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+
+    enum ggml_type result_type = dst->type;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+
+    void * grad_q = (char *) dst->data;
+    void * grad_k = (char *) dst->data + offs_k;
+    void * grad_v = (char *) dst->data + offs_v;
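+    // the q, k and v gradients are packed back-to-back into dst ([grad_q | grad_k | grad_v],
+    // each region padded to GGML_MEM_ALIGN), so their row/plane strides are rebuilt below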
+
+    const size_t nbgq1 = nb0*neq0;
+    const size_t nbgq2 = nb0*neq0*neq1;
+    const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+    const size_t nbgk1 = nb0*nek0;
+    const size_t nbgk2 = nb0*nek0*nek1;
+    const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+    const size_t nbgv1 = nb0*nev0;
+    const size_t nbgv2 = nb0*nev0*nev1;
+    const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+    // parallelize by k rows using ggml_vec_dot_f32
+
+    // total rows in k
+    const int nr = nek2*nek3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    // how often k2 (and v2) is repeated in q2
+    int nrep = neq2/nek2;
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // k indices
+        const int ik3 = ir/(nek2);
+        const int ik2 = ir - ik3*nek2;
+
+        const int iq3 = ik3;
+        const int id3 = ik3;
+        const int iv3 = ik3;
+        const int iv2 = ik2;
+
+        for (int irep = 0; irep < nrep; ++irep) {
+            const int iq2 = ik2 + irep*nek2;
+            const int id2 = iq2;
+
+            // (ik2 + irep*nek2) % nek2 == ik2
+            for (int iq1 = 0; iq1 < neq1; ++iq1) {
+                const int id1 = iq1;
+
+                // not sure about CACHE_LINE_SIZE_F32..
+                // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+                for (int i = M; i < Mup; ++i) {
+                    S[i] = -INFINITY;
+                }
+
+                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    // k indices
+                    const int ik1 = ic;
+
+                    // S indices
+                    const int i1 = ik1;
+
+                    ggml_vec_dot_f32(neq0,
+                            S + i1, 0,
+                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
+                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
+                }
+
+                // scale
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                for (int64_t i = masked_begin; i < M; i++) {
+                    S[i] = -INFINITY;
+                }
+
+                // softmax
+                // exclude known -INF S[..] values from max and loop
+                // dont forget to set their SM values to zero
+                {
+                    float max = -INFINITY;
+                    ggml_vec_max_f32(masked_begin, &max, S);
+
+                    ggml_float sum = 0.0;
+                    {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                        max = -max;
+                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                        vvexpf(SM, SM, &Mup);
+                        ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
+#endif
+                    }
+
+                    assert(sum > 0.0);
+
+                    sum = 1.0/sum;
+                    ggml_vec_scale_f32(masked_begin, SM, sum);
+
+                }
+
+                // step-by-step explanation
+                {
+                    // forward-process                    shape      grads from backward process
+                    // parallel_for ik2,ik3:
+                    //  for irep:
+                    //   iq2 = ik2 + irep*nek2
+                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
+                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
+                    //   for iq1:
+                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                    //    S0     = -Inf                   [D,1,1,1]
+                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
+                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                    //   ~S5[i]  = dot(vcur[:,i], S4)
+                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
+                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
+                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
+                    // dst                               backward-/ grad[dst]                 = d
+                    //
+                    // output gradients with their dependencies:
+                    //
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S4]   = grad[S5] @ vcur
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[vcur] = grad[S5].T @ S4
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // in post-order:
+                    //
+                    // S1         = qcur @ kcur.T
+                    // S2         = S1 * scale
+                    // S3         = diag_mask_inf(S2, P)
+                    // S4         = softmax(S3)
+                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
+                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                    // grad[qcur] = grad[S1]   @ kcur
+                    // grad[kcur] = grad[S1].T @ qcur
+                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
+                    //
+                    // using less variables (SM=S4):
+                    //
+                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
+                    // SM            = softmax(S)
+                    // S             = d[:D,iq1,iq2,iq3] @ vcur
+                    // dot_SM_gradSM = dot(SM, S)
+                    // S             = SM * (S - dot(SM, S))
+                    // S             = diag_mask_zero(S, P) * scale
+                    //
+                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
+                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
+                }
+
+                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
+                // for ic:
+                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
+                // exclude known future zero S[..] values from operation
+                ggml_vec_set_f32(masked_begin, S, 0);
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            S,
+                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
+                }
+
+                // S = SM * (S - dot(SM, S))
+                float dot_SM_gradSM = 0;
+                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
+                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+                ggml_vec_mul_f32 (masked_begin, S, S, SM);
+
+                // S = diag_mask_zero(S, P) * scale
+                // already done by above ggml_vec_set_f32
+
+                // exclude known zero S[..] values from operation
+                ggml_vec_scale_f32(masked_begin, S, scale);
+
+                // S    shape [M,1]
+                // SM   shape [M,1]
+                // kcur shape [D,M]
+                // qcur shape [D,1]
+                // vcur shape [M,D]
+
+                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+                // for ic:
+                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
+                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
+                            S[ic]);
+                }
+
+                // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
+                // for ic:
+                //  grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+                //  grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+                // exclude known zero S[..] values from loop
+                for (int64_t ic = 0; ic < masked_begin; ++ic) {
+                    ggml_vec_mad_f32(D,
+                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
+                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
+                            S[ic]);
+                }
+
+                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
+                // for ic:
+                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
+                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
+                // exclude known zero SM[..] values from mad
+                for (int64_t ic = 0; ic < D; ++ic) {
+                    ggml_vec_mad_f32(masked_begin,
+                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
+                            SM,
+                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const bool masked,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * q = dst->src[0];
+
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_ssm_conv
+
+static void ggml_compute_forward_ssm_conv_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t =  dst->ne[1]; // tokens per sequence
+    const int n_s =  dst->ne[2]; // number of sequences in the batch
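+
+    // in effect a depthwise, causal 1D convolution: conv_x already holds the previous
+    // d_conv - 1 values per channel, so each output element is the dot product of a
+    // d_conv-wide sliding window of conv_x with that channel's row of conv1d.weight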
+
+    GGML_ASSERT( dst->ne[0] == nr);
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
+
+            // TODO: transpose the output for smaller strides for big batches?
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+
+                // d_conv
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
+                }
+                x[i1] = sumf;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_conv(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_conv_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_ssm_scan
+
+static void ggml_compute_forward_ssm_scan_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // s
+    const struct ggml_tensor * src1 = dst->src[1]; // x
+    const struct ggml_tensor * src2 = dst->src[2]; // dt
+    const struct ggml_tensor * src3 = dst->src[3]; // A
+    const struct ggml_tensor * src4 = dst->src[4]; // B
+    const struct ggml_tensor * src5 = dst->src[5]; // C
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nc  = src0->ne[0]; // d_state
+    const int64_t nr  = src0->ne[1]; // d_inner
+    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
+    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
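+
+    // selective state-space (Mamba) scan, per inner channel and token:
+    //   dt'   = softplus(dt)                         (passes dt through unchanged for dt > 20)
+    //   state = state * exp(dt' * A) + B * (x * dt')
+    //   y     = dot(state, C)
+    // the final state of each sequence is written after the y values (offset src1->nb[3] into dst)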
+
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src2->nb[0] == sizeof(float));
+    GGML_ASSERT(src3->nb[0] == sizeof(float));
+    GGML_ASSERT(src4->nb[0] == sizeof(float));
+    GGML_ASSERT(src5->nb[0] == sizeof(float));
+    // required for the dot product between s and C
+    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+    const int ir  = ir1 - ir0;
+
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4->data +  i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5->data +  i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
+                  float * y  = (      float *) ((      char *) dst->data  + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = (      float *) ((      char *) dst->data  + ir0*(src0->nb[1]) + i3*(src0->nb[2]) +     src1->nb[3]);  // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
+                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
+                float x_dt = x[i1] * dt_soft_plus;
+                float sumf = 0.0f;
+                // d_state
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    // state = prev_state * dA + dB * x
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    // y = rowwise_dotprod(state, C)
+                    sumf += state * C[i0];
+                    s[i] = state;
+                }
+                y[i1] = sumf;
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_ssm_scan(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    switch (dst->src[0]->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_ssm_scan_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_win_part
+
+static void ggml_compute_forward_win_part_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+
+    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
+    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
+    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
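+
+    // partitions the ne01 x ne02 spatial grid of src0 into nep0 x nep1 non-overlapping w x w
+    // windows stacked along dst's outermost dim; positions past the input edge are zero-filled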
+
+    assert(ne00 == ne0);
+    assert(ne3  == nep0*nep1);
+
+    // TODO: optimize / multi-thread
+    for (int py = 0; py < nep1; ++py) {
+        for (int px = 0; px < nep0; ++px) {
+            const int64_t i3 = py*nep0 + px;
+            for (int64_t i2 = 0; i2 < ne2; ++i2) {
+                for (int64_t i1 = 0; i1 < ne1; ++i1) {
+                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                        const int64_t i02 = py*w + i2;
+                        const int64_t i01 = px*w + i1;
+                        const int64_t i00 = i0;
+
+                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
+                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
+
+                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
+                            ((float *) dst->data)[i] = 0.0f;
+                        } else {
+                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_part(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_part_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_win_unpart
+
+static void ggml_compute_forward_win_unpart_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
+
+    const int32_t w = ((const int32_t *)(dst->op_params))[0];
+
+    // padding
+    const int px = (w - ne1%w)%w;
+    //const int py = (w - ne2%w)%w;
+
+    const int npx = (px + ne1)/w;
+    //const int npy = (py + ne2)/w;
+
+    assert(ne0 == ne00);
+
+    // TODO: optimize / multi-thread
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int ip2 = i2/w;
+                const int ip1 = i1/w;
+
+                const int64_t i02 = i2%w;
+                const int64_t i01 = i1%w;
+                const int64_t i00 = i0;
+
+                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
+                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
+
+                ((float *) dst->data)[j] = ((float *) src0->data)[i];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_win_unpart(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_win_unpart_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_unary
+
+static void ggml_compute_forward_unary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const enum ggml_unary_op op = ggml_get_unary_op(dst);
+
+    switch (op) {
+        case GGML_UNARY_OP_ABS:
+            {
+                ggml_compute_forward_abs(params, dst);
+            } break;
+        case GGML_UNARY_OP_SGN:
+            {
+                ggml_compute_forward_sgn(params, dst);
+            } break;
+        case GGML_UNARY_OP_NEG:
+            {
+                ggml_compute_forward_neg(params, dst);
+            } break;
+        case GGML_UNARY_OP_STEP:
+            {
+                ggml_compute_forward_step(params, dst);
+            } break;
+        case GGML_UNARY_OP_TANH:
+            {
+                ggml_compute_forward_tanh(params, dst);
+            } break;
+        case GGML_UNARY_OP_ELU:
+            {
+                ggml_compute_forward_elu(params, dst);
+            } break;
+        case GGML_UNARY_OP_RELU:
+            {
+                ggml_compute_forward_relu(params, dst);
+            } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, dst);
+            } break;
+        case GGML_UNARY_OP_GELU:
+            {
+                ggml_compute_forward_gelu(params, dst);
+            } break;
+        case GGML_UNARY_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, dst);
+            } break;
+        case GGML_UNARY_OP_SILU:
+            {
+                ggml_compute_forward_silu(params, dst);
+            } break;
+        case GGML_UNARY_OP_HARDSWISH:
+            {
+                ggml_compute_forward_hardswish(params, dst);
+            } break;
+        case GGML_UNARY_OP_HARDSIGMOID:
+            {
+                ggml_compute_forward_hardsigmoid(params, dst);
+            } break;
+        case GGML_UNARY_OP_EXP:
+            {
+                ggml_compute_forward_exp(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_get_rel_pos
+
+static void ggml_compute_forward_get_rel_pos_f16(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    UNUSED(params);
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int64_t w = ne1;
+
+    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
+    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
+
+    for (int64_t i2 = 0; i2 < ne2; ++i2) {
+        for (int64_t i1 = 0; i1 < ne1; ++i1) {
+            const int64_t pos = (w - i1 - 1) + i2;
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_get_rel_pos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
+            {
+                ggml_compute_forward_get_rel_pos_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_add_rel_pos
+
+static void ggml_compute_forward_add_rel_pos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
+
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
+    if (!inplace) {
+        if (params->ith == 0) {
+            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
+
+    float * src1_data = (float *) src1->data;
+    float * src2_data = (float *) src2->data;
+    float * dst_data  = (float *) dst->data;
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // total patches in dst
+    const int np = ne13;
+
+    // patches per thread
+    const int dp = (np + nth - 1)/nth;
+
+    // patch range for this thread
+    const int ip0 = dp*ith;
+    const int ip1 = MIN(ip0 + dp, np);
+
+    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
+                for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                    const int64_t jp0  = jp1 + i10;
+                    const float src1_e = src1_data[jp0];
+                    const float src2_e = src2_data[jp0];
+
+                    const int64_t jdh = jp0 * ne10;
+                    const int64_t jdw = jdh - (ne10 - 1) * i10;
+
+                    for (int64_t j = 0; j < ne10; ++j) {
+                        dst_data[jdh + j     ] += src2_e;
+                        dst_data[jdw + j*ne10] += src1_e;
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_add_rel_pos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_rel_pos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_rwkv_wkv6
+
+static void ggml_compute_forward_rwkv_wkv6_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k =          (float *) dst->src[0]->data;
+    float * v =          (float *) dst->src[1]->data;
+    float * r =          (float *) dst->src[2]->data;
+    float * time_faaaa = (float *) dst->src[3]->data;
+    float * time_decay = (float *) dst->src[4]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
+
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
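+
+        // same recurrence as the scalar #else branch below, WKV_VECTOR_SIZE lanes at a time:
+        //   temp  = (k*v) * time_faaaa + state
+        //   dst  += temp * r
+        //   state = state * time_decay + (k*v)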
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
+
+                    // Handle any remaining elements (head_size is expected to be a multiple of WKV_VECTOR_SIZE, so this normally does not run)
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_rwkv_wkv6(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, dst, fun);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_custom1_op_f32_t fun) {
+
+    const struct ggml_tensor * a = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    fun(dst, a);
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_custom2_op_f32_t fun) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    fun(dst, a, b);
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst,
+        const ggml_custom3_op_f32_t fun) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+    const struct ggml_tensor * c = dst->src[2];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    fun(dst, a, b, c);
+}
+
+// ggml_compute_forward_map_custom1
+
+static void ggml_compute_forward_map_custom1(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * a = dst->src[0];
+
+    struct ggml_map_custom1_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom2
+
+static void ggml_compute_forward_map_custom2(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+
+    struct ggml_map_custom2_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_map_custom3
+
+static void ggml_compute_forward_map_custom3(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * a = dst->src[0];
+    const struct ggml_tensor * b = dst->src[1];
+    const struct ggml_tensor * c = dst->src[2];
+
+    struct ggml_map_custom3_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
+}
+
+// ggml_compute_forward_cross_entropy_loss
+
+static void ggml_compute_forward_cross_entropy_loss_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
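+
+    // loss = -1/nr * sum over rows of sum_j src1[j] * log_softmax(src0)[j]
+    // each thread accumulates its rows into sums[ith]; thread 0 reduces and negates at the end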
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    float * sums =  (float *) params->wdata;
+    float * st   = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
+
+    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
+
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
+        ggml_vec_mul_f32(nc, st, st, s1);
+
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            assert(!isnan(st[i]));
+            assert(!isinf(st[i]));
+        }
+#endif
+    }
+    sums[ith] = sum_thread;
+    ggml_barrier(params->threadpool);
+
+    if (ith == 0) {
+        float * dp = (float *) dst->data;
+        ggml_vec_sum_f32(nth, dp, sums);
+        dp[0] *= -1.0f / (float) nr;
+    }
+}
+
+static void ggml_compute_forward_cross_entropy_loss(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cross_entropy_loss_back
+
+static void ggml_compute_forward_cross_entropy_loss_back_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    const int64_t ith = params->ith;
+    const int64_t nth = params->nth;
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+
+    for (int64_t i1 = ir0; i1 < ir1; i1++) {
+        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
+        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            //printf("p[%d] = %f\n", i, p[i]);
+            assert(!isnan(s0[i]));
+            assert(!isnan(s1[i]));
+        }
+#endif
+
+        // soft_max
+        float max = -INFINITY;
+        ggml_vec_max_f32(nc, &max, s0);
+        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        assert(sum > 0.0);
+        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
+
+        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
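+        // (derivation note: for targets y with sum_j y_j = 1,
+        //  d/dx_i [ -sum_j y_j * log_softmax(x)_j ] = softmax(x)_i - y_i,
+        //  so the backward pass only needs the row softmax and a subtraction;
+        //  d_by_nr carries the upstream gradient together with the 1/nr scaling)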
+        ggml_vec_sub_f32(nc, ds0, ds0, s1);
+        ggml_vec_scale_f32(nc, ds0, d_by_nr);
+
+#ifndef NDEBUG
+        for (int64_t i = 0; i < nc; ++i) {
+            assert(!isnan(ds0[i]));
+            assert(!isinf(ds0[i]));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_cross_entropy_loss_back(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static void ggml_compute_forward_opt_step_adamw_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0         = dst->src[0];
+    const struct ggml_tensor * src0_grad    = dst->src[1];
+    const struct ggml_tensor * src0_grad_m  = dst->src[2];
+    const struct ggml_tensor * src0_grad_v  = dst->src[3];
+    const struct ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha  = adamw_params_ptr[0];
+    const float beta1  = adamw_params_ptr[1];
+    const float beta2  = adamw_params_ptr[2];
+    const float eps    = adamw_params_ptr[3];
+    const float wd     = adamw_params_ptr[4];
+    const float beta1h = adamw_params_ptr[5];
+    const float beta2h = adamw_params_ptr[6];
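+    // NOTE (assumption): beta1h/beta2h are expected to be the bias-correction factors for the
+    // first and second moments (roughly 1/(1 - beta1^t) and 1/(1 - beta2^t)), precomputed by the
+    // caller and passed in via adamw_params together with alpha, beta1, beta2, eps and wd.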
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
+
+        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
+        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
+        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
+            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
+
+            const float mh =       m[i00]*beta1h;
+            const float vh = sqrtf(v[i00]*beta2h) + eps;
+
+            // The weight decay is applied independently of the Adam momenta m and v.
+            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
+            // See: https://arxiv.org/pdf/1711.05101v3.pdf
+            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+        }
+    }
+}
+
+static void ggml_compute_forward_opt_step_adamw(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_adamw_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+/////////////////////////////////
+
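+// Dispatches a single graph node to the matching per-op forward kernel. Empty tensors and
+// GGML_OP_NONE nodes are no-ops. ggml_cpu_extra_compute_forward gives extra CPU buffer types
+// (e.g. repacked/accelerated weight layouts) a chance to handle the node first; when it
+// returns true the generic kernel below is skipped.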
+static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    GGML_ASSERT(params);
+
+    if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
+        return;
+    }
+
+    // extra_buffer op?
+    if (ggml_cpu_extra_compute_forward(params, tensor)) return;
+
+    switch (tensor->op) {
+        case GGML_OP_DUP:
+            {
+                ggml_compute_forward_dup(params, tensor);
+            } break;
+        case GGML_OP_ADD:
+            {
+                ggml_compute_forward_add(params, tensor);
+            } break;
+        case GGML_OP_ADD1:
+            {
+                ggml_compute_forward_add1(params, tensor);
+            } break;
+        case GGML_OP_ACC:
+            {
+                ggml_compute_forward_acc(params, tensor);
+            } break;
+        case GGML_OP_SUB:
+            {
+                ggml_compute_forward_sub(params, tensor);
+            } break;
+        case GGML_OP_MUL:
+            {
+                ggml_compute_forward_mul(params, tensor);
+            } break;
+        case GGML_OP_DIV:
+            {
+                ggml_compute_forward_div(params, tensor);
+            } break;
+        case GGML_OP_SQR:
+            {
+                ggml_compute_forward_sqr(params, tensor);
+            } break;
+        case GGML_OP_SQRT:
+            {
+                ggml_compute_forward_sqrt(params, tensor);
+            } break;
+        case GGML_OP_LOG:
+            {
+                ggml_compute_forward_log(params, tensor);
+            } break;
+        case GGML_OP_SIN:
+            {
+                ggml_compute_forward_sin(params, tensor);
+            } break;
+        case GGML_OP_COS:
+            {
+                ggml_compute_forward_cos(params, tensor);
+            } break;
+        case GGML_OP_SUM:
+            {
+                ggml_compute_forward_sum(params, tensor);
+            } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                ggml_compute_forward_sum_rows(params, tensor);
+            } break;
+        case GGML_OP_MEAN:
+            {
+                ggml_compute_forward_mean(params, tensor);
+            } break;
+        case GGML_OP_ARGMAX:
+            {
+                ggml_compute_forward_argmax(params, tensor);
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                ggml_compute_forward_count_equal(params, tensor);
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                ggml_compute_forward_repeat(params, tensor);
+            } break;
+        case GGML_OP_REPEAT_BACK:
+            {
+                ggml_compute_forward_repeat_back(params, tensor);
+            } break;
+        case GGML_OP_CONCAT:
+            {
+                ggml_compute_forward_concat(params, tensor);
+            } break;
+        case GGML_OP_SILU_BACK:
+            {
+                ggml_compute_forward_silu_back(params, tensor);
+            } break;
+        case GGML_OP_NORM:
+            {
+                ggml_compute_forward_norm(params, tensor);
+            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                ggml_compute_forward_rms_norm(params, tensor);
+            } break;
+        case GGML_OP_RMS_NORM_BACK:
+            {
+                ggml_compute_forward_rms_norm_back(params, tensor);
+            } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                ggml_compute_forward_group_norm(params, tensor);
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                ggml_compute_forward_mul_mat(params, tensor);
+            } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                ggml_compute_forward_mul_mat_id(params, tensor);
+            } break;
+        case GGML_OP_OUT_PROD:
+            {
+                ggml_compute_forward_out_prod(params, tensor);
+            } break;
+        case GGML_OP_SCALE:
+            {
+                ggml_compute_forward_scale(params, tensor);
+            } break;
+        case GGML_OP_SET:
+            {
+                ggml_compute_forward_set(params, tensor);
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_compute_forward_cpy(params, tensor);
+            } break;
+        case GGML_OP_CONT:
+            {
+                ggml_compute_forward_cont(params, tensor);
+            } break;
+        case GGML_OP_RESHAPE:
+            {
+                ggml_compute_forward_reshape(params, tensor);
+            } break;
+        case GGML_OP_VIEW:
+            {
+                ggml_compute_forward_view(params, tensor);
+            } break;
+        case GGML_OP_PERMUTE:
+            {
+                ggml_compute_forward_permute(params, tensor);
+            } break;
+        case GGML_OP_TRANSPOSE:
+            {
+                ggml_compute_forward_transpose(params, tensor);
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                ggml_compute_forward_get_rows(params, tensor);
+            } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                ggml_compute_forward_get_rows_back(params, tensor);
+            } break;
+        case GGML_OP_DIAG:
+            {
+                ggml_compute_forward_diag(params, tensor);
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                ggml_compute_forward_diag_mask_inf(params, tensor);
+            } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+            {
+                ggml_compute_forward_diag_mask_zero(params, tensor);
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                ggml_compute_forward_soft_max(params, tensor);
+            } break;
+        case GGML_OP_SOFT_MAX_BACK:
+            {
+                ggml_compute_forward_soft_max_back(params, tensor);
+            } break;
+        case GGML_OP_ROPE:
+            {
+                ggml_compute_forward_rope(params, tensor);
+            } break;
+        case GGML_OP_ROPE_BACK:
+            {
+                ggml_compute_forward_rope_back(params, tensor);
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                ggml_compute_forward_clamp(params, tensor);
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_compute_forward_conv_transpose_1d(params, tensor);
+            } break;
+        case GGML_OP_IM2COL:
+            {
+                ggml_compute_forward_im2col(params, tensor);
+            } break;
+        case GGML_OP_IM2COL_BACK:
+            {
+                ggml_compute_forward_im2col_back_f32(params, tensor);
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                ggml_compute_forward_conv_transpose_2d(params, tensor);
+            } break;
+        case GGML_OP_POOL_1D:
+            {
+                ggml_compute_forward_pool_1d(params, tensor);
+            } break;
+        case GGML_OP_POOL_2D:
+            {
+                ggml_compute_forward_pool_2d(params, tensor);
+            } break;
+        case GGML_OP_POOL_2D_BACK:
+            {
+                ggml_compute_forward_pool_2d_back(params, tensor);
+            } break;
+        case GGML_OP_UPSCALE:
+            {
+                ggml_compute_forward_upscale(params, tensor);
+            } break;
+        case GGML_OP_PAD:
+            {
+                ggml_compute_forward_pad(params, tensor);
+            } break;
+        case GGML_OP_PAD_REFLECT_1D:
+            {
+                ggml_compute_forward_pad_reflect_1d(params, tensor);
+            } break;
+        case GGML_OP_UNPAD:
+            {
+                ggml_compute_forward_unpad(params, tensor);
+            } break;
+        case GGML_OP_ARANGE:
+            {
+                ggml_compute_forward_arange(params, tensor);
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                ggml_compute_forward_timestep_embedding(params, tensor);
+            } break;
+        case GGML_OP_ARGSORT:
+            {
+                ggml_compute_forward_argsort(params, tensor);
+            } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                ggml_compute_forward_leaky_relu(params, tensor);
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_op_params_i32(tensor, 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, masked, tensor);
+            } break;
+        case GGML_OP_SSM_CONV:
+            {
+                ggml_compute_forward_ssm_conv(params, tensor);
+            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                ggml_compute_forward_ssm_scan(params, tensor);
+            } break;
+        case GGML_OP_WIN_PART:
+            {
+                ggml_compute_forward_win_part(params, tensor);
+            } break;
+        case GGML_OP_WIN_UNPART:
+            {
+                ggml_compute_forward_win_unpart(params, tensor);
+            } break;
+        case GGML_OP_UNARY:
+            {
+                ggml_compute_forward_unary(params, tensor);
+            } break;
+        case GGML_OP_GET_REL_POS:
+            {
+                ggml_compute_forward_get_rel_pos(params, tensor);
+            } break;
+        case GGML_OP_ADD_REL_POS:
+            {
+                ggml_compute_forward_add_rel_pos(params, tensor);
+            } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
+            } break;
+        case GGML_OP_MAP_UNARY:
+            {
+                ggml_unary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_unary(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                ggml_binary_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_binary(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM1_F32:
+            {
+                ggml_custom1_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM2_F32:
+            {
+                ggml_custom2_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM3_F32:
+            {
+                ggml_custom3_op_f32_t fun;
+                memcpy(&fun, tensor->op_params, sizeof(fun));
+                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                ggml_compute_forward_map_custom1(params, tensor);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                ggml_compute_forward_map_custom2(params, tensor);
+            }
+            break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                ggml_compute_forward_map_custom3(params, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            {
+                ggml_compute_forward_cross_entropy_loss(params, tensor);
+            }
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            {
+                ggml_compute_forward_cross_entropy_loss_back(params, tensor);
+            }
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                ggml_compute_forward_opt_step_adamw(params, tensor);
+            }
+            break;
+        case GGML_OP_NONE:
+            {
+                // nop
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// Android's libc implementation "bionic" does not support setting affinity
+#if defined(__gnu_linux__)
+static void set_numa_thread_affinity(int thread_n) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    int node_num;
+    int rv;
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    switch(g_state.numa.numa_strategy) {
+        case GGML_NUMA_STRATEGY_DISTRIBUTE:
+            // run thread on node_num thread_n / (threads per node)
+            node_num = thread_n % g_state.numa.n_nodes;
+            break;
+        case GGML_NUMA_STRATEGY_ISOLATE:
+            // run thread on current_node
+            node_num = g_state.numa.current_node;
+            break;
+        case GGML_NUMA_STRATEGY_NUMACTL:
+            // use the cpuset that numactl gave us
+            rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset);
+            if (rv) {
+                fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv));
+            }
+            return;
+        default:
+            return;
+    }
+
+    struct ggml_numa_node * node = &g_state.numa.nodes[node_num];
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (size_t i = 0; i < node->n_cpus; ++i) {
+        CPU_SET_S(node->cpus[i], setsize, cpus);
+    }
+
+    rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+
+static void clear_numa_thread_affinity(void) {
+    if (!ggml_is_numa()) {
+        return;
+    }
+
+    size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus);
+
+    cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus);
+    CPU_ZERO_S(setsize, cpus);
+    for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) {
+        CPU_SET_S(i, setsize, cpus);
+    }
+
+    int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus);
+    if (rv) {
+        fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv));
+    }
+
+    CPU_FREE(cpus);
+}
+#else
+// TODO: Windows etc.
+// (the linux implementation may also work on BSD, someone should test)
+static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n);  }
+static void clear_numa_thread_affinity(void) {}
+#endif
+
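+// Returns how many threads should cooperate on a given node. Cheap or inherently serial ops
+// (views, reshapes, argmax, pooling, ...) run on a single thread; heavier ops use all available
+// threads, and the MAP_CUSTOM* ops honor the user-requested n_tasks from their op params.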
+static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+    int n_tasks = 0;
+
+    if (ggml_is_empty(node)) {
+        // no need to multi-thread a no-op
+        n_tasks = 1;
+        return n_tasks;
+    }
+
+    switch (node->op) {
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_ACC:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_SUB:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_LOG:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_SUM:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
+        case GGML_OP_ARGMAX:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_REPEAT:
+        case GGML_OP_REPEAT_BACK:
+        case GGML_OP_LEAKY_RELU:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(node)) {
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_EXP:
+                    {
+                        n_tasks = 1;
+                    } break;
+
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
+        case GGML_OP_SILU_BACK:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_CONCAT:
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_OUT_PROD:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                // FIXME: get_rows can use additional threads, but the cost of launching additional threads
+                // decreases performance with GPU offloading
+                //n_tasks = n_threads;
+                n_tasks = 1;
+            } break;
+        case GGML_OP_SCALE:
+        case GGML_OP_SET:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_GET_ROWS_BACK:
+        case GGML_OP_DIAG:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_DIAG_MASK_ZERO:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX_BACK:
+        case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_ADD_REL_POS:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                n_tasks = 1; //TODO
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+            } break;
+        case GGML_OP_IM2COL:
+        case GGML_OP_IM2COL_BACK:
+        case GGML_OP_CONV_TRANSPOSE_1D:
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_POOL_1D:
+        case GGML_OP_POOL_2D:
+        case GGML_OP_POOL_2D_BACK:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_PAD_REFLECT_1D:
+        case GGML_OP_UNPAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_FLASH_ATTN_EXT:
+        case GGML_OP_FLASH_ATTN_BACK:
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
+        case GGML_OP_GET_REL_POS:
+        case GGML_OP_RWKV_WKV6:
+        case GGML_OP_MAP_UNARY:
+        case GGML_OP_MAP_BINARY:
+        case GGML_OP_MAP_CUSTOM1_F32:
+        case GGML_OP_MAP_CUSTOM2_F32:
+        case GGML_OP_MAP_CUSTOM3_F32:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_MAP_CUSTOM1:
+            {
+                struct ggml_map_custom1_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_MAP_CUSTOM2:
+            {
+                struct ggml_map_custom2_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_MAP_CUSTOM3:
+            {
+                struct ggml_map_custom3_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
+            {
+                n_tasks = n_threads;
+            } break;
+        case GGML_OP_NONE:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT:
+            {
+                GGML_ABORT("fatal error");
+            }
+        default:
+            {
+                fprintf(stderr, "%s: op not implemented: ", __func__);
+                if (node->op < GGML_OP_COUNT) {
+                    fprintf(stderr, "%s\n", ggml_op_name(node->op));
+                } else {
+                    fprintf(stderr, "%d\n", node->op);
+                }
+                GGML_ABORT("fatal error");
+            }
+    }
+
+    assert(n_tasks > 0);
+
+    return n_tasks;
+}
+
+static thread_ret_t ggml_graph_compute_secondary_thread(void* data);
+
+#if defined(_WIN32)
+#include "windows.h"
+
+// TODO: support > 64 CPUs
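+// The cpumask is an array of GGML_MAX_N_THREADS booleans; the first 64 entries are packed into
+// a 64-bit bitmask for SetThreadAffinityMask, and any entry set beyond index 63 only triggers
+// the warning below.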
+static bool ggml_thread_apply_affinity(bool * mask) {
+    HANDLE    h = GetCurrentThread();
+    uint64_t  bitmask = 0ULL;
+
+    assert(GGML_MAX_N_THREADS >= 64);
+
+    for (int32_t i = 0; i < 8; i++) {
+        int32_t idx = i * 8;
+        uint8_t val = 0;
+        val |= mask[idx + 0] << 0;
+        val |= mask[idx + 1] << 1;
+        val |= mask[idx + 2] << 2;
+        val |= mask[idx + 3] << 3;
+        val |= mask[idx + 4] << 4;
+        val |= mask[idx + 5] << 5;
+        val |= mask[idx + 6] << 6;
+        val |= mask[idx + 7] << 7;
+        bitmask |= (uint64_t)val << idx;
+    }
+
+    for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) {
+        if (mask[i]) {
+            fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n");
+            break;
+        }
+    }
+
+    DWORD_PTR m = (DWORD_PTR)bitmask;
+
+    m = SetThreadAffinityMask(h, m);
+
+    return m != 0;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    // Note that on Windows the Process Priority Class must be updated in order to set Thread priority.
+    // This is left up to the application.
+    DWORD p = THREAD_PRIORITY_NORMAL;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   p = THREAD_PRIORITY_NORMAL;        break;
+        case GGML_SCHED_PRIO_MEDIUM:   p = THREAD_PRIORITY_ABOVE_NORMAL;  break;
+        case GGML_SCHED_PRIO_HIGH:     p = THREAD_PRIORITY_HIGHEST;       break;
+        case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
+    }
+
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        // Keep inherited policy/priority
+        return true;
+    }
+
+    if (!SetThreadPriority(GetCurrentThread(), p)) {
+        fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError());
+        return false;
+    }
+
+    return true;
+}
+
+#elif defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/resource.h>
+
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    // Not supported on Apple platforms
+    UNUSED(mask);
+    return true;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    struct sched_param p;
+    int32_t policy = SCHED_OTHER;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
+        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
+        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
+        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
+    }
+
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        // Keep inherited policy/priority
+        return true;
+    }
+
+    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
+    if (err != 0) {
+        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
+        return false;
+    }
+
+    return true;
+}
+
+#elif defined(__gnu_linux__)
+// TODO: this may not work on BSD, to be verified
+
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    cpu_set_t cpuset;
+    int err;
+
+    CPU_ZERO(&cpuset);
+
+    for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+        if (mask[i]) {
+            GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i);
+            CPU_SET(i, &cpuset);
+        }
+    }
+
+#ifdef __ANDROID__
+    err = sched_setaffinity(0, sizeof(cpuset), &cpuset);
+    if (err < 0) {
+        err = errno;
+    }
+#else
+    err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset);
+#endif
+    if (err != 0) {
+        fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err);
+        return false;
+    }
+
+    return true;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    struct sched_param p;
+    int32_t policy = SCHED_OTHER;
+    switch (prio) {
+        case GGML_SCHED_PRIO_NORMAL:   policy = SCHED_OTHER; p.sched_priority = 0;  break;
+        case GGML_SCHED_PRIO_MEDIUM:   policy = SCHED_FIFO;  p.sched_priority = 40; break;
+        case GGML_SCHED_PRIO_HIGH:     policy = SCHED_FIFO;  p.sched_priority = 80; break;
+        case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO;  p.sched_priority = 90; break;
+    }
+
+    if (prio == GGML_SCHED_PRIO_NORMAL) {
+        // Keep inherited policy/priority
+        return true;
+    }
+
+    int32_t err = pthread_setschedparam(pthread_self(), policy, &p);
+    if (err != 0) {
+        fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err);
+        return false;
+    }
+
+    return true;
+}
+
+#else // unsupported platforms
+
+static bool ggml_thread_apply_affinity(const bool * mask) {
+    UNUSED(mask);
+    return true;
+}
+
+static bool ggml_thread_apply_priority(int32_t prio) {
+    UNUSED(prio);
+    return true;
+}
+
+#endif
+
+static bool ggml_thread_cpumask_is_valid(const bool * mask) {
+    for (int i = 0; i < GGML_MAX_N_THREADS; i++) {
+        if (mask[i]) { return true; }
+    }
+    return false;
+}
+
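+// Picks the CPU placement for the next worker thread. In relaxed mode the whole global mask is
+// used (the OS may schedule the thread on any allowed CPU); in strict mode a single CPU is
+// selected round-robin from the global mask so each worker gets pinned to its own core.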
+static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
+    if (!strict) {
+        memcpy(local_mask, global_mask, GGML_MAX_N_THREADS);
+        return;
+    } else {
+        memset(local_mask, 0, GGML_MAX_N_THREADS);
+        int32_t base_idx = *iter;
+        for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) {
+            int32_t idx = base_idx + i;
+            if (idx >= GGML_MAX_N_THREADS) {
+                // Just a cheaper modulo
+                idx -= GGML_MAX_N_THREADS;
+            }
+            if (global_mask[idx]) {
+                local_mask[idx] = 1;
+                *iter = idx + 1;
+                return;
+            }
+        }
+    }
+}
+
+void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
+    if (!threadpool) return;
+
+    const int n_threads = threadpool->n_threads_max;
+
+#ifndef GGML_USE_OPENMP
+    struct ggml_compute_state* workers = threadpool->workers;
+
+    ggml_mutex_lock(&threadpool->mutex);
+
+    threadpool->stop = true;
+    threadpool->pause = false;
+
+    ggml_cond_broadcast(&threadpool->cond);
+    ggml_mutex_unlock(&threadpool->mutex);
+
+    for (int j = 1; j < n_threads; j++) {
+        int32_t rc = ggml_thread_join(workers[j].thrd, NULL);
+        GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED);
+        UNUSED(rc);
+    }
+
+    ggml_mutex_destroy(&threadpool->mutex);
+    ggml_cond_destroy(&threadpool->cond);
+#endif // GGML_USE_OPENMP
+
+    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
+    ggml_aligned_free(threadpool->workers, workers_size);
+    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
+}
+
+#ifndef GGML_USE_OPENMP
+// pause/resume must be called under mutex
+static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) {
+    GGML_PRINT_DEBUG("Pausing threadpool\n");
+    threadpool->pause = true;
+    ggml_cond_broadcast(&threadpool->cond);
+}
+
+static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) {
+    GGML_PRINT_DEBUG("Resuming threadpool\n");
+    threadpool->pause = false;
+    ggml_cond_broadcast(&threadpool->cond);
+}
+#endif
+
+void ggml_threadpool_pause(struct ggml_threadpool * threadpool) {
+#ifndef GGML_USE_OPENMP
+    ggml_mutex_lock(&threadpool->mutex);
+    if (!threadpool->pause) {
+       ggml_threadpool_pause_locked(threadpool);
+    }
+    ggml_mutex_unlock(&threadpool->mutex);
+#else
+    UNUSED(threadpool);
+#endif
+}
+
+void ggml_threadpool_resume(struct ggml_threadpool * threadpool) {
+#ifndef GGML_USE_OPENMP
+    ggml_mutex_lock(&threadpool->mutex);
+    if (threadpool->pause) {
+       ggml_threadpool_resume_locked(threadpool);
+    }
+    ggml_mutex_unlock(&threadpool->mutex);
+#else
+    UNUSED(threadpool);
+#endif
+}
+
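+// Builds an execution plan for a graph: decides how many threads to use (capped by the largest
+// per-node task count) and estimates the scratch work buffer size needed by the ops. The caller
+// must allocate cplan.work_data of cplan.work_size bytes before calling ggml_graph_compute
+// (see ggml_graph_compute_with_ctx below for the context-backed variant).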
+struct ggml_cplan ggml_graph_plan(
+          const struct ggml_cgraph * cgraph,
+                               int   n_threads,
+            struct ggml_threadpool * threadpool) {
+
+    if (threadpool == NULL) {
+        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
+    }
+    if (n_threads <= 0) {
+        n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS;
+    }
+
+    size_t work_size = 0;
+
+    struct ggml_cplan cplan;
+    memset(&cplan, 0, sizeof(struct ggml_cplan));
+
+    int max_tasks = 1;
+
+    // thread scheduling for the different operations + work buffer size estimation
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        const int n_tasks = ggml_get_n_tasks(node, n_threads);
+
+        max_tasks = MAX(max_tasks, n_tasks);
+
+        size_t cur = 0;
+
+        if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
+
+            switch (node->op) {
+                case GGML_OP_CPY:
+                case GGML_OP_DUP:
+                    {
+                        if (ggml_is_quantized(node->type) ||
+                            // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
+                            (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
+                            (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
+                            cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                        }
+                    } break;
+                case GGML_OP_ADD:
+                case GGML_OP_ADD1:
+                    {
+                        if (ggml_is_quantized(node->src[0]->type)) {
+                            cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+                        }
+                    } break;
+                case GGML_OP_ACC:
+                    {
+                        if (ggml_is_quantized(node->src[0]->type)) {
+                            cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
+                        }
+                    } break;
+                case GGML_OP_COUNT_EQUAL:
+                    {
+                        cur = ggml_type_size(node->type)*n_tasks;
+                    } break;
+                case GGML_OP_MUL_MAT:
+                    {
+                        const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
+
+                        if (node->src[1]->type != vec_dot_type) {
+                            cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
+                        }
+                    } break;
+                case GGML_OP_MUL_MAT_ID:
+                    {
+                        cur = 0;
+                        const struct ggml_tensor * src0 = node->src[0];
+                        const struct ggml_tensor * src1 = node->src[1];
+                        const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
+                        if (src1->type != vec_dot_type) {
+                            cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
+                        }
+                        const int n_as = src0->ne[2];
+                        cur += GGML_PAD(cur, sizeof(int64_t));       // align
+                        cur += n_as * sizeof(int64_t);               // matrix_row_counts
+                        cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                    } break;
+                case GGML_OP_OUT_PROD:
+                    {
+                        if (ggml_is_quantized(node->src[0]->type)) {
+                            cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
+                        }
+                    } break;
+                case GGML_OP_SOFT_MAX:
+                case GGML_OP_ROPE:
+                    {
+                        cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                    } break;
+                case GGML_OP_CONV_TRANSPOSE_1D:
+                    {
+                        GGML_ASSERT(node->src[0]->ne[3] == 1);
+                        GGML_ASSERT(node->src[1]->ne[2] == 1);
+                        GGML_ASSERT(node->src[1]->ne[3] == 1);
+
+                        const int64_t ne00 = node->src[0]->ne[0];  // K
+                        const int64_t ne01 = node->src[0]->ne[1];  // Cout
+                        const int64_t ne02 = node->src[0]->ne[2];  // Cin
+                        const int64_t ne10 = node->src[1]->ne[0];  // L
+                        const int64_t ne11 = node->src[1]->ne[1];  // Cin
+
+                        if ((node->src[0]->type == GGML_TYPE_F16 ||
+                             node->src[0]->type == GGML_TYPE_BF16) &&
+                            node->src[1]->type == GGML_TYPE_F32) {
+                            cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02;
+                            cur += sizeof(ggml_fp16_t)*ne10*ne11;
+                        } else if (node->src[0]->type == GGML_TYPE_F32 &&
+                                   node->src[1]->type == GGML_TYPE_F32) {
+                            cur += sizeof(float)*ne00*ne01*ne02;
+                            cur += sizeof(float)*ne10*ne11;
+                        } else {
+                            GGML_ABORT("fatal error");
+                        }
+                    } break;
+                case GGML_OP_CONV_TRANSPOSE_2D:
+                    {
+                        const int64_t ne00 = node->src[0]->ne[0]; // W
+                        const int64_t ne01 = node->src[0]->ne[1]; // H
+                        const int64_t ne02 = node->src[0]->ne[2]; // Channels Out
+                        const int64_t ne03 = node->src[0]->ne[3]; // Channels In
+
+                        const int64_t ne10 = node->src[1]->ne[0]; // W
+                        const int64_t ne11 = node->src[1]->ne[1]; // H
+                        const int64_t ne12 = node->src[1]->ne[2]; // Channels In
+
+                        cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03;
+                        cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12;
+                    } break;
+                case GGML_OP_FLASH_ATTN_EXT:
+                    {
+                        const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                    } break;
+                case GGML_OP_FLASH_ATTN_BACK:
+                    {
+                        const int64_t    D = node->src[0]->ne[0];
+                        const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
+                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                        if (node->src[1]->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                        } else if (node->src[1]->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                        } else if (node->src[1]->type == GGML_TYPE_BF16) {
+                            cur  = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+                        }
+                    } break;
+
+                case GGML_OP_CROSS_ENTROPY_LOSS:
+                    {
+                        cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
+                    } break;
+                case GGML_OP_COUNT:
+                    {
+                        GGML_ABORT("fatal error");
+                    }
+                default:
+                    break;
+            }
+        }
+
+        work_size = MAX(work_size, cur);
+    }
+
+    if (work_size > 0) {
+        work_size += CACHE_LINE_SIZE*(n_threads);
+    }
+
+    cplan.threadpool = threadpool;
+    cplan.n_threads  = MIN(max_tasks, n_threads);
+    cplan.work_size  = work_size;
+    cplan.work_data  = NULL;
+
+    return cplan;
+}
+
+static thread_ret_t ggml_graph_compute_thread(void * data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_threadpool    * tp    = state->threadpool;
+
+    const struct ggml_cgraph * cgraph = tp->cgraph;
+    const struct ggml_cplan  * cplan  = tp->cplan;
+
+    set_numa_thread_affinity(state->ith);
+
+    struct ggml_compute_params params = {
+        /*.ith       =*/ state->ith,
+        /*.nth       =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed),
+        /*.wsize     =*/ cplan->work_size,
+        /*.wdata     =*/ cplan->work_data,
+        /*.threadpool=*/ tp,
+    };
+
+    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+        struct ggml_tensor * node = cgraph->nodes[node_n];
+
+        ggml_compute_forward(¶ms, node);
+
+        if (state->ith == 0 && cplan->abort_callback &&
+                cplan->abort_callback(cplan->abort_callback_data)) {
+            tp->abort = true;
+            tp->ec    = GGML_STATUS_ABORTED;
+        }
+
+        ggml_barrier(state->threadpool);
+    }
+
+    return 0;
+}
+
+#ifndef GGML_USE_OPENMP
+
+// check if thread is active
+static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+    int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed);
+    return (state->ith < n_threads);
+}
+
+// check if thread is ready to proceed (exit from polling or sleeping)
+static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    if (state->pending || threadpool->stop || threadpool->pause) { return true; }
+
+    // check for new graph/work
+    int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed);
+    if (new_graph != state->last_graph) {
+        state->pending    = ggml_graph_compute_thread_active(state);
+        state->last_graph = new_graph;
+    }
+
+    return state->pending;
+}
+
+// sync thread state after polling
+static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) {
+    // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead
+    #ifdef GGML_TSAN_ENABLED
+    atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst);
+    #else
+    atomic_thread_fence(memory_order_seq_cst);
+    #endif
+    UNUSED(state);
+}
+
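+// Hybrid wait: a worker first busy-polls for new work for a number of rounds scaled by the
+// threadpool's poll level (0 disables polling entirely) and only then falls back to sleeping on
+// the condition variable in ggml_graph_compute_check_for_work.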
+static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    // Skip polling for unused threads
+    if (!ggml_graph_compute_thread_active(state)) {
+        return state->pending;
+    }
+
+    // This seems to make 0 ... 100 a decent range for polling level across modern processors.
+    // Perhaps we could adjust it dynamically based on load.
+    const uint64_t n_rounds = 1024UL * 128 * threadpool->poll;
+
+    for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) {
+        // No new work. Keep polling.
+        ggml_thread_cpu_relax();
+    }
+
+    return state->pending;
+}
+
+static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) {
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    if (ggml_graph_compute_poll_for_work(state)) {
+        ggml_graph_compute_thread_sync(state);
+        return state->pending;
+    }
+
+    ggml_mutex_lock_shared(&threadpool->mutex);
+    while (!ggml_graph_compute_thread_ready(state)) {
+        // No new work. Wait for the signal.
+        GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith);
+        ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
+    }
+    ggml_mutex_unlock_shared(&threadpool->mutex);
+
+    return state->pending;
+}
+
+static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
+    struct ggml_compute_state * state = (struct ggml_compute_state *) data;
+    struct ggml_threadpool * threadpool = state->threadpool;
+
+    ggml_thread_apply_priority(threadpool->prio);
+    if (ggml_thread_cpumask_is_valid(state->cpumask)) {
+        ggml_thread_apply_affinity(state->cpumask);
+    }
+
+    while (true) {
+        // Check if we need to sleep
+        while (threadpool->pause) {
+            GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith);
+            ggml_mutex_lock_shared(&threadpool->mutex);
+            if (threadpool->pause) {
+                ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
+            }
+            GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith);
+            ggml_mutex_unlock_shared(&threadpool->mutex);
+        }
+
+        // This needs to be checked for after the cond_wait
+        if (threadpool->stop) break;
+
+        // Check if there is new work
+        // The main thread is the only one that can dispatch new work
+
+        ggml_graph_compute_check_for_work(state);
+        if (state->pending) {
+            state->pending = false;
+
+            ggml_graph_compute_thread(state);
+        }
+    }
+
+    return (thread_ret_t) 0;
+}
+
+// Start processing new graph
+static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads)
+{
+    // Always take the mutex here because the worker threads are doing hybrid poll/wait
+
+    ggml_mutex_lock(&threadpool->mutex);
+
+    GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads);
+
+    // Update the number of active threads
+    atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+
+    // Indicate the graph is ready to be processed
+    // We need the full seq-cst fence here because of the polling threads (used in thread_sync)
+    atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst);
+
+    if (threadpool->pause) {
+       // Update main thread prio and affinity to match the threadpool settings
+       ggml_thread_apply_priority(threadpool->prio);
+       if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
+           ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
+       }
+
+       // resume does cond broadcast
+       ggml_threadpool_resume_locked(threadpool);
+    } else {
+       ggml_cond_broadcast(&threadpool->cond);
+    }
+
+    ggml_mutex_unlock(&threadpool->mutex);
+}
+
+#endif // GGML_USE_OPENMP
+
+static struct ggml_threadpool * ggml_threadpool_new_impl(
+    struct ggml_threadpool_params * tpp,
+               struct ggml_cgraph * cgraph,
+                struct ggml_cplan * cplan) {
+
+    struct ggml_threadpool * threadpool =
+        ggml_aligned_malloc(sizeof(struct ggml_threadpool));
+    {
+        threadpool->cgraph           = cgraph;
+        threadpool->cplan            = cplan;
+        threadpool->n_graph          = 0;
+        threadpool->n_barrier        = 0;
+        threadpool->n_barrier_passed = 0;
+        threadpool->current_chunk    = 0;
+        threadpool->stop             = false;
+        threadpool->pause            = tpp->paused;
+        threadpool->abort            = false;
+        threadpool->workers          = NULL;
+        threadpool->n_threads_max    = tpp->n_threads;
+        threadpool->n_threads_cur    = tpp->n_threads;
+        threadpool->poll             = tpp->poll;
+        threadpool->prio             = tpp->prio;
+        threadpool->ec               = GGML_STATUS_SUCCESS;
+    }
+
+    // Allocate and init workers state
+    const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
+    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);
+
+    memset(workers, 0, workers_size);
+    for (int j = 0; j < tpp->n_threads; j++) {
+        workers[j].threadpool = threadpool;
+        workers[j].ith        = j;
+    }
+
+    threadpool->workers = workers;
+
+#ifndef GGML_USE_OPENMP
+    ggml_mutex_init(&threadpool->mutex);
+    ggml_cond_init(&threadpool->cond);
+
+    // Spin the threads for all workers, and update CPU placements.
+    // Place the main thread last (towards the higher numbered CPU cores).
+
+    int32_t cpumask_iter = 0;
+
+    for (int j = 1; j < tpp->n_threads; j++) {
+        ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter);
+
+        int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]);
+        GGML_ASSERT(rc == 0);
+    }
+
+    ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter);
+
+    if (!threadpool->pause) {
+        // Update main thread prio and affinity at the start, otherwise we'll do it in resume
+        ggml_thread_apply_priority(threadpool->prio);
+        if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) {
+            ggml_thread_apply_affinity(threadpool->workers[0].cpumask);
+        }
+    }
+#endif // GGML_USE_OPENMP
+
+    return threadpool;
+}
+
+struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) {
+    return ggml_threadpool_new_impl(tpp, NULL, NULL);
+}
+
+enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+    ggml_cpu_init();
+
+    GGML_ASSERT(cplan);
+    GGML_ASSERT(cplan->n_threads > 0);
+    GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL);
+
+    int n_threads                               = cplan->n_threads;
+    struct ggml_threadpool * threadpool = cplan->threadpool;
+
+    bool disposable_threadpool = false;
+
+    if (threadpool == NULL) {
+        //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads);
+        disposable_threadpool = true;
+
+        struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads);
+        threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan);
+    } else {
+        // Reset some of the parameters that need resetting
+        // No worker threads should be accessing the parameters below at this stage
+        threadpool->cgraph           = cgraph;
+        threadpool->cplan            = cplan;
+        threadpool->current_chunk    = 0;
+        threadpool->abort            = false;
+        threadpool->ec               = GGML_STATUS_SUCCESS;
+    }
+
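+    // Two execution modes: with GGML_USE_OPENMP the workers are OpenMP threads spawned per graph
+    // (thread 0 records the actual thread count OpenMP granted); otherwise the persistent
+    // threadpool is kicked off and the calling thread participates as worker 0.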
+#ifdef GGML_USE_OPENMP
+    if (n_threads > 1) {
+        #pragma omp parallel num_threads(n_threads)
+        {
+            #pragma omp single
+            {
+                // update the number of threads from the actual number of threads that we got from OpenMP
+                n_threads = omp_get_num_threads();
+                atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed);
+            }
+
+            ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]);
+        }
+    } else {
+        atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed);
+        ggml_graph_compute_thread(&threadpool->workers[0]);
+    }
+#else
+    if (n_threads > threadpool->n_threads_max) {
+        GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
+        n_threads = threadpool->n_threads_max;
+    }
+
+    // Kick all threads to start the new graph
+    ggml_graph_compute_kickoff(threadpool, n_threads);
+
+    // This is a work thread too
+    ggml_graph_compute_thread(&threadpool->workers[0]);
+#endif
+
+    // don't leave affinity set on the main thread
+    clear_numa_thread_affinity();
+
+    enum ggml_status ret = threadpool->ec;
+
+    if (disposable_threadpool) {
+        ggml_threadpool_free(threadpool);
+    }
+
+    return ret;
+}
+
+enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL);
+
+    cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size);
+
+    return ggml_graph_compute(cgraph, &cplan);
+}
+
+
+int ggml_cpu_has_avx(void) {
+#if defined(__AVX__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx_vnni(void) {
+#if defined(__AVXVNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx2(void) {
+#if defined(__AVX2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512(void) {
+#if defined(__AVX512F__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vbmi(void) {
+#if defined(__AVX512VBMI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_vnni(void) {
+#if defined(__AVX512VNNI__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_amx_int8(void) {
+#if defined(__AMX_INT8__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fma(void) {
+#if defined(__FMA__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_arm_fma(void) {
+#if defined(__ARM_FEATURE_FMA)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_riscv_v(void) {
+#if defined(__riscv_v_intrinsic)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_f16c(void) {
+#if defined(__F16C__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_fp16_va(void) {
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_wasm_simd(void) {
+#if defined(__wasm_simd128__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_llamafile(void) {
+#if defined(GGML_USE_LLAMAFILE)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_sse3(void) {
+#if defined(__SSE3__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_ssse3(void) {
+#if defined(__SSSE3__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_vsx(void) {
+#if defined(__POWER9_VECTOR__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_neon(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_NEON)
+    return ggml_arm_arch_features.has_neon;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_dotprod(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD)
+    return ggml_arm_arch_features.has_dotprod;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_sve(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
+    return ggml_arm_arch_features.has_sve;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_matmul_int8(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8)
+    return ggml_arm_arch_features.has_i8mm;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_get_sve_cnt(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE)
+    return ggml_arm_arch_features.sve_cnt;
+#else
+    return 0;
+#endif
+}
+
+void ggml_cpu_init(void) {
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL, false };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+
+    ggml_critical_section_start();
+
+    static bool is_first_call = true;
+
+    if (is_first_call) {
+        // initialize GELU, Quick GELU, SILU and EXP F32 tables
+        {
+            const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
+
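+            // fill 2^16-entry lookup tables covering every possible fp16 bit pattern, so the
+            // GELU and Quick GELU activations can later be evaluated with a single table lookup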
+            for (int i = 0; i < (1 << 16); ++i) {
+                union {
+                    uint16_t u16;
+                    ggml_fp16_t fp16;
+                } u = {i};
+                float f = GGML_FP16_TO_FP32(u.fp16);
+                ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
+            }
+
+            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
+
+            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
+        }
+
+#if defined(__ARM_ARCH)
+        ggml_init_arm_arch_features();
+#endif
+
+        is_first_call = false;
+    }
+
+    ggml_critical_section_end();
+}
diff --git a/llama/ggml-cpu.cpp b/llama/ggml-cpu.cpp
new file mode 100644
index 000000000..38395101f
--- /dev/null
+++ b/llama/ggml-cpu.cpp
@@ -0,0 +1,653 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-cpu-aarch64.h"
+#include "ggml-cpu-traits.h"
+#include "ggml-impl.h"
+#include "amx.h"
+#include <cctype>
+#include <string>
+#include <vector>
+
+#ifdef GGML_USE_CPU_HBM
+#include "ggml-cpu-hbm.h"
+#endif
+
+#if defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+// ggml-backend interface
+
+std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
+    static std::vector<ggml_backend_buffer_type_t> bufts = []() {
+        std::vector<ggml_backend_buffer_type_t> bufts;
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+        if (ggml_backend_amx_buffer_type()) {
+            bufts.push_back(ggml_backend_amx_buffer_type());
+        }
+#endif
+
+#ifdef GGML_USE_CPU_AARCH64
+        if (ggml_backend_cpu_aarch64_buffer_type()) {
+            bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
+        }
+#endif
+
+        bufts.push_back(NULL);
+
+        return bufts;
+    }();
+
+    return bufts;
+}
+
+static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
+    return ggml_backend_cpu_get_extra_buffers_type().data();
+
+    GGML_UNUSED(device);
+}
+
+static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
+    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra && extra == buft) return true;
+    }
+    return false;
+}
+
+// CPU backend - backend (stream)
+
+struct ggml_backend_cpu_context {
+    int                 n_threads;
+    ggml_threadpool_t   threadpool;
+
+    uint8_t *           work_data;
+    size_t              work_size;
+
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
+};
+
+static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
+    return "CPU";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    delete[] cpu_ctx->work_data;
+    delete cpu_ctx;
+    delete backend;
+}
+
+struct ggml_backend_plan_cpu {
+    struct ggml_cplan cplan;
+    struct ggml_cgraph cgraph;
+};
+
+static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;
+
+    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+    cpu_plan->cgraph = *cgraph; // FIXME: deep copy
+
+    if (cpu_plan->cplan.work_size > 0) {
+        cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];
+        if (cpu_plan->cplan.work_data == NULL) {
+            delete cpu_plan;
+            return NULL;
+        }
+    }
+
+    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return cpu_plan;
+}
+
+static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    delete[] cpu_plan->cplan.work_data;
+    delete cpu_plan;
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+
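+    // grow the reusable work buffer only when this graph plan needs more scratch space than before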
+    if (cpu_ctx->work_size < cplan.work_size) {
+        delete[] cpu_ctx->work_data;
+        cpu_ctx->work_data = new uint8_t[cplan.work_size];
+        if (cpu_ctx->work_data == NULL) {
+            cpu_ctx->work_size = 0;
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+        cpu_ctx->work_size = cplan.work_size;
+    }
+    cplan.work_data = (uint8_t *)cpu_ctx->work_data;
+
+    cplan.abort_callback      = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return ggml_graph_compute(cgraph, &cplan);
+}
+
+static const struct ggml_backend_i ggml_backend_cpu_i = {
+    /* .get_name                = */ ggml_backend_cpu_get_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_cpu_guid(void) {
+    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_cpu_init(void) {
+    // initialize CPU backend now to avoid slowing the first graph computation
+    ggml_cpu_init();
+
+    struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;
+    if (ctx == NULL) {
+        return NULL;
+    }
+
+    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
+    ctx->threadpool          = NULL;
+    ctx->work_data           = NULL;
+    ctx->work_size           = 0;
+    ctx->abort_callback      = NULL;
+    ctx->abort_callback_data = NULL;
+
+    ggml_backend_t cpu_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_cpu_guid(),
+        /* .interface = */ ggml_backend_cpu_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+    if (cpu_backend == NULL) {
+        delete ctx;
+        return NULL;
+    }
+
+    return cpu_backend;
+}
+
+bool ggml_backend_is_cpu(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
+}
+
+void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->n_threads = n_threads;
+}
+
+void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_threadpool_pause(ctx->threadpool);
+    }
+    ctx->threadpool = threadpool;
+}
+
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
+// CPU backend - device
+
+struct ggml_backend_cpu_device_context {
+    std::string description = "CPU";
+
+    ggml_backend_cpu_device_context() {
+#ifdef __APPLE__
+        size_t len = 0;
+        if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) {
+            description.resize(len);
+            sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT
+        }
+#elif defined(__linux__)
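+        // scan /proc/cpuinfo for the first "model name" entry and trim surrounding whitespace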
+        FILE * f = fopen("/proc/cpuinfo", "r");
+        if (f) {
+            char buf[1024];
+            while (fgets(buf, sizeof(buf), f)) {
+                if (strncmp(buf, "model name", 10) == 0) {
+                    char * p = strchr(buf, ':');
+                    if (p) {
+                        p++;
+                        while (std::isspace(*p)) {
+                            p++;
+                        }
+                        while (std::isspace(p[strlen(p) - 1])) {
+                            p[strlen(p) - 1] = '\0';
+                        }
+                        description = p;
+                        break;
+                    }
+                }
+            }
+            fclose(f);
+        }
+#elif defined(_WIN32)
+        HKEY hKey;
+        if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                        TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
+                        0,
+                        KEY_READ,
+                        &hKey) == ERROR_SUCCESS) {
+            DWORD cpu_brand_size = 0;
+            if (RegQueryValueExA(hKey,
+                                TEXT("ProcessorNameString"),
+                                NULL,
+                                NULL,
+                                NULL,
+                                &cpu_brand_size) == ERROR_SUCCESS) {
+                description.resize(cpu_brand_size);
+                if (RegQueryValueExA(hKey,
+                                    TEXT("ProcessorNameString"),
+                                    NULL,
+                                    NULL,
+                                    (LPBYTE)&description[0], // NOLINT
+                                    &cpu_brand_size) == ERROR_SUCCESS) {
+                    if (description.find('\0') != std::string::npos) {
+                        description.resize(description.find('\0'));
+                    }
+                }
+            }
+            RegCloseKey(hKey);
+        }
+#endif
+    }
+};
+
+static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {
+    return "CPU";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {
+    struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;
+
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_cpu_device_get_name(dev);
+    props->description = ggml_backend_cpu_device_get_description(dev);
+    props->type        = ggml_backend_cpu_device_get_type(dev);
+    ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    if (op->op == GGML_OP_NONE || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_VIEW || op->op == GGML_OP_PERMUTE || op->op == GGML_OP_TRANSPOSE) {
+        return true;
+    }
+
+    // extra_buffer_op?
+    for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra) {
+            auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
+            if (buf_extra && buf_extra->supports_op(dev, op)) {
+                return true;
+            }
+        }
+    }
+
+    // the other cases need host buffers.
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
+            return false;
+        }
+    }
+
+    switch (op->op) {
+        case GGML_OP_CPY:
+            return
+                op->type != GGML_TYPE_IQ3_XXS &&
+                op->type != GGML_TYPE_IQ3_S   &&
+                op->type != GGML_TYPE_IQ2_XXS &&
+                op->type != GGML_TYPE_IQ2_XS  &&
+                op->type != GGML_TYPE_IQ2_S   &&
+                op->type != GGML_TYPE_IQ1_S   &&
+                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
+        case GGML_OP_MUL_MAT:
+            return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
+        case GGML_OP_ROPE_BACK:
+            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_IM2COL_BACK:
+            return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
+        case GGML_OP_OUT_PROD:
+            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+        default:
+            return true;
+    }
+}
+
+static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_is_extra_buffer_type(buft);
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
+    /* .get_name             = */ ggml_backend_cpu_device_get_name,
+    /* .get_description      = */ ggml_backend_cpu_device_get_description,
+    /* .get_memory           = */ ggml_backend_cpu_device_get_memory,
+    /* .get_type             = */ ggml_backend_cpu_device_get_type,
+    /* .get_props            = */ ggml_backend_cpu_device_get_props,
+    /* .init_backend         = */ ggml_backend_cpu_device_init_backend,
+    /* .get_buffer_type      = */ ggml_backend_cpu_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
+    /* .supports_op          = */ ggml_backend_cpu_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_cpu_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// CPU backend - backend (reg)
+
+static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
+    return "CPU";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_cpu_device_context ctx;
+    static ggml_backend_device ggml_backend_cpu_device = {
+        /* .iface   = */ ggml_backend_cpu_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ &ctx,
+    };
+
+    return &ggml_backend_cpu_device;
+}
+
+// This is intended to replace the ggml_cpu_has_* functions when loading the CPU backend dynamically,
+// and additionally to allow other backends to expose their own list of features that applications can query using the same API
+static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) {
+    static std::vector<ggml_backend_feature> features = []() {
+        ggml_cpu_init();
+
+        std::vector<ggml_backend_feature> features;
+        if (ggml_cpu_has_sse3()) {
+            features.push_back({ "SSE3", "1" });
+        }
+        if (ggml_cpu_has_ssse3()) {
+            features.push_back({ "SSSE3", "1" });
+        }
+        if (ggml_cpu_has_avx()) {
+            features.push_back({ "AVX", "1" });
+        }
+        if (ggml_cpu_has_avx_vnni()) {
+            features.push_back({ "AVX_VNNI", "1" });
+        }
+        if (ggml_cpu_has_avx2()) {
+            features.push_back({ "AVX2", "1" });
+        }
+        if (ggml_cpu_has_f16c()) {
+            features.push_back({ "F16C", "1" });
+        }
+        if (ggml_cpu_has_fma()) {
+            features.push_back({ "FMA", "1" });
+        }
+        if (ggml_cpu_has_avx512()) {
+            features.push_back({ "AVX512", "1" });
+        }
+        if (ggml_cpu_has_avx512_vbmi()) {
+            features.push_back({ "AVX512_VBMI", "1" });
+        }
+        if (ggml_cpu_has_avx512_vnni()) {
+            features.push_back({ "AVX512_VNNI", "1" });
+        }
+        if (ggml_cpu_has_avx512_bf16()) {
+            features.push_back({ "AVX512_BF16", "1" });
+        }
+        if (ggml_cpu_has_amx_int8()) {
+            features.push_back({ "AMX_INT8", "1" });
+        }
+        if (ggml_cpu_has_neon()) {
+            features.push_back({ "NEON", "1" });
+        }
+        if (ggml_cpu_has_arm_fma()) {
+            features.push_back({ "ARM_FMA", "1" });
+        }
+        if (ggml_cpu_has_fp16_va()) {
+            features.push_back({ "FP16_VA", "1" });
+        }
+        if (ggml_cpu_has_matmul_int8()) {
+            features.push_back({ "MATMUL_INT8", "1" });
+        }
+        if (ggml_cpu_has_sve()) {
+            features.push_back({ "SVE", "1" });
+        }
+        if (ggml_cpu_has_dotprod()) {
+            features.push_back({ "DOTPROD", "1" });
+        }
+        if (ggml_cpu_has_matmul_int8()) {
+            features.push_back({ "MATMUL_INT8", "1" });
+        }
+        if (ggml_cpu_get_sve_cnt() > 0) {
+            static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
+            features.push_back({ "SVE_CNT", sve_cnt.c_str() });
+        }
+        if (ggml_cpu_has_riscv_v()) {
+            features.push_back({ "RISCV_V", "1" });
+        }
+        if (ggml_cpu_has_vsx()) {
+            features.push_back({ "VSX", "1" });
+        }
+        if (ggml_cpu_has_wasm_simd()) {
+            features.push_back({ "WASM_SIMD", "1" });
+        }
+        if (ggml_cpu_has_llamafile()) {
+            features.push_back({ "LLAMAFILE", "1" });
+        }
+    #ifdef GGML_USE_ACCELERATE
+        features.push_back({ "ACCELERATE", "1" });
+    #endif
+    #ifdef GGML_USE_CPU_HBM
+        features.push_back({ "CPU_HBM", "1" });
+    #endif
+    #ifdef GGML_USE_OPENMP
+        features.push_back({ "OPENMP", "1" });
+    #endif
+    #ifdef GGML_USE_CPU_AARCH64
+        features.push_back({ "AARCH64_REPACK", "1" });
+    #endif
+
+        features.push_back({ nullptr, nullptr });
+
+        return features;
+    }();
+
+    return features.data();
+
+    GGML_UNUSED(reg);
+}
+
+static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads;
+        return (void *)fct;
+    }
+    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
+        ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type;
+        return (void *)fct;
+    }
+    if (strcmp(name, "ggml_backend_get_features") == 0) {
+        return (void *)ggml_backend_cpu_get_features;
+    }
+    if (strcmp(name, "ggml_backend_set_abort_callback") == 0) {
+        return (void *)ggml_backend_cpu_set_abort_callback;
+    }
+    if (strcmp(name, "ggml_backend_cpu_numa_init") == 0) {
+        return (void *)ggml_numa_init;
+    }
+    if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
+        return (void *)ggml_is_numa;
+    }
+
+    // threadpool - TODO:  move to ggml-base
+    if (strcmp(name, "ggml_threadpool_new") == 0) {
+        return (void *)ggml_threadpool_new;
+    }
+    if (strcmp(name, "ggml_threadpool_free") == 0) {
+        return (void *)ggml_threadpool_free;
+    }
+    if (strcmp(name, "ggml_backend_cpu_set_threadpool") == 0) {
+        return (void *)ggml_backend_cpu_set_threadpool;
+    }
+
+    return NULL;
+
+    GGML_UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
+    /* .get_name         = */ ggml_backend_cpu_reg_get_name,
+    /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_cpu_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_cpu_reg(void) {
+    // init CPU feature detection
+    ggml_cpu_init();
+
+    static struct ggml_backend_reg ggml_backend_cpu_reg = {
+        /* .api_version = */ GGML_BACKEND_API_VERSION,
+        /* .iface       = */ ggml_backend_cpu_reg_i,
+        /* .context     = */ NULL,
+    };
+
+    return &ggml_backend_cpu_reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_cpu_reg)
diff --git a/llama/ggml-cpu.h b/llama/ggml-cpu.h
new file mode 100644
index 000000000..c2b64e662
--- /dev/null
+++ b/llama/ggml-cpu.h
@@ -0,0 +1,161 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    // the compute plan that needs to be prepared for ggml_graph_compute()
+    // since https://github.com/ggerganov/ggml/issues/287
+    struct ggml_cplan {
+        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+        int n_threads;
+        struct ggml_threadpool * threadpool;
+
+        // abort ggml_graph_compute when true
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
+    };
+
+    // numa strategies
+    enum ggml_numa_strategy {
+        GGML_NUMA_STRATEGY_DISABLED   = 0,
+        GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
+        GGML_NUMA_STRATEGY_ISOLATE    = 2,
+        GGML_NUMA_STRATEGY_NUMACTL    = 3,
+        GGML_NUMA_STRATEGY_MIRROR     = 4,
+        GGML_NUMA_STRATEGY_COUNT
+    };
+
+    GGML_BACKEND_API void    ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
+    GGML_BACKEND_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
+    GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+    GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+    GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+    GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+    GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_BACKEND_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+    GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_BACKEND_API void    ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+
+    GGML_BACKEND_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_BACKEND_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+
+    GGML_BACKEND_API float   ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+    GGML_BACKEND_API void    ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+
+    GGML_BACKEND_API struct ggml_threadpool *      ggml_threadpool_new           (struct ggml_threadpool_params  * params);
+    GGML_BACKEND_API void                          ggml_threadpool_free          (struct ggml_threadpool * threadpool);
+    GGML_BACKEND_API int                           ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
+    GGML_BACKEND_API void                          ggml_threadpool_pause         (struct ggml_threadpool * threadpool);
+    GGML_BACKEND_API void                          ggml_threadpool_resume        (struct ggml_threadpool * threadpool);
+
+    // ggml_graph_plan() has to be called before ggml_graph_compute()
+    // when plan.work_size > 0, caller must allocate memory for plan.work_data
+    GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
+                  const struct ggml_cgraph * cgraph,
+                                       int   n_threads, /* = GGML_DEFAULT_N_THREADS */
+                    struct ggml_threadpool * threadpool /* = NULL */ );
+    GGML_BACKEND_API enum ggml_status  ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_BACKEND_API enum ggml_status  ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
+
+    //
+    // system info
+    //
+
+    // x86
+    GGML_BACKEND_API int ggml_cpu_has_sse3       (void);
+    GGML_BACKEND_API int ggml_cpu_has_ssse3      (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx        (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx_vnni   (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx2       (void);
+    GGML_BACKEND_API int ggml_cpu_has_f16c       (void);
+    GGML_BACKEND_API int ggml_cpu_has_fma        (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512     (void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
+    GGML_BACKEND_API int ggml_cpu_has_amx_int8   (void);
+    // ARM
+    GGML_BACKEND_API int ggml_cpu_has_neon       (void);
+    GGML_BACKEND_API int ggml_cpu_has_arm_fma    (void);
+    GGML_BACKEND_API int ggml_cpu_has_fp16_va    (void);
+    GGML_BACKEND_API int ggml_cpu_has_dotprod    (void);
+    GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
+    GGML_BACKEND_API int ggml_cpu_has_sve        (void);
+    GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
+    // other
+    GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void);
+    GGML_BACKEND_API int ggml_cpu_has_vsx        (void);
+    GGML_BACKEND_API int ggml_cpu_has_wasm_simd  (void);
+    GGML_BACKEND_API int ggml_cpu_has_llamafile  (void);
+
+    // Internal types and functions exposed for tests and benchmarks
+
+    typedef void (*ggml_vec_dot_t)  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                       const void * GGML_RESTRICT y, size_t by, int nrc);
+
+    struct ggml_type_traits_cpu {
+        ggml_from_float_t        from_float;
+        ggml_vec_dot_t           vec_dot;
+        enum ggml_type           vec_dot_type;
+        int64_t                  nrows; // number of rows to process simultaneously
+    };
+
+    GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
+
+    GGML_BACKEND_API void ggml_cpu_init(void);
+
+    //
+    // CPU backend
+    //
+
+    GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
+
+    GGML_BACKEND_API bool ggml_backend_is_cpu                (ggml_backend_t backend);
+    GGML_BACKEND_API void ggml_backend_cpu_set_n_threads     (ggml_backend_t backend_cpu, int n_threads);
+    GGML_BACKEND_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
+    GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+    GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml-cuda.h b/llama/ggml-cuda.h
new file mode 100644
index 000000000..c0fb681e9
--- /dev/null
+++ b/llama/ggml-cuda.h
@@ -0,0 +1,73 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+#ifdef GGML_USE_HIP
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#elif defined(GGML_USE_MUSA)
+#define GGML_CUDA_NAME "MUSA"
+#define GGML_CUBLAS_NAME "muBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+#define GGML_CUDA_MAX_DEVICES       16
+
+// backend API
+GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
+
+GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
+
+// device buffer
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+
+// split tensor buffer that splits matrices by rows across multiple devices
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
+
+// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+
+GGML_BACKEND_API int  ggml_backend_cuda_get_device_count(void);
+GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
+GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
+
+GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/llama/ggml-cuda/acc.cu b/llama/ggml-cuda/acc.cu
new file mode 100644
index 000000000..9ce47e60d
--- /dev/null
+++ b/llama/ggml-cuda/acc.cu
@@ -0,0 +1,73 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "acc.cuh"
+
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset) {
+    const int i = blockDim.x * blockIdx.x + threadIdx.x;
+    if (i >= ne) {
+        return;
+    }
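+    // decompose the flat element index into (ox, oy, oz) coordinates within src1 using its
+    // element strides nb1/nb2; elements that fall outside src1 are copied from x unchanged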
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+    int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+    acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
+}
diff --git a/llama/ggml-cuda/acc.cuh b/llama/ggml-cuda/acc.cuh
new file mode 100644
index 000000000..5c12d9066
--- /dev/null
+++ b/llama/ggml-cuda/acc.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_ACC_BLOCK_SIZE 256
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/arange.cu b/llama/ggml-cuda/arange.cu
new file mode 100644
index 000000000..3b67b3b5f
--- /dev/null
+++ b/llama/ggml-cuda/arange.cu
@@ -0,0 +1,60 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "arange.cuh"
+
+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+    dst[nidx] = start + step * nidx;
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+    arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start,  step);
+}
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float start;
+    float stop;
+    float step;
+    memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+    memcpy(&stop,  (float *)dst->op_params + 1, sizeof(float));
+    memcpy(&step,  (float *)dst->op_params + 2, sizeof(float));
+
+    int64_t steps = (int64_t)ceil((stop - start) / step);
+    GGML_ASSERT(ggml_nelements(dst) == steps);
+
+    arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
+}
diff --git a/llama/ggml-cuda/arange.cuh b/llama/ggml-cuda/arange.cuh
new file mode 100644
index 000000000..16201546b
--- /dev/null
+++ b/llama/ggml-cuda/arange.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_ARANGE_BLOCK_SIZE 256
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/argmax.cu b/llama/ggml-cuda/argmax.cu
new file mode 100644
index 000000000..8bbfd7c05
--- /dev/null
+++ b/llama/ggml-cuda/argmax.cu
@@ -0,0 +1,117 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <algorithm>
+#include <cstdint>
+
+#include "argmax.cuh"
+#include "common.cuh"
+#include "sum.cuh"
+
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+    const int64_t row = blockIdx.x;
+
+    float maxval = -FLT_MAX;
+    int   argmax = -1;
+    const float * rowx = x + row * ncols;
+
+    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
+        const float val = rowx[col];
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
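+    // warp-level butterfly reduction: exchange (maxval, argmax) pairs with shuffle-xor so that
+    // every lane in the warp ends up holding the warp-wide maximum and its column index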
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+    const int n_warps = blockDim.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    if (n_warps > 1) {
+        constexpr int    max_warps = 1024 / WARP_SIZE;
+        __shared__ float shared_maxval[max_warps];
+        __shared__ int   shared_argmax[max_warps];
+        if (lane_id == 0) {
+            shared_maxval[warp_id] = maxval;
+            shared_argmax[warp_id] = argmax;
+        }
+
+        __syncthreads();
+
+        if (warp_id == 0) {
+            if (lane_id < n_warps) {
+                maxval = shared_maxval[lane_id];
+                argmax = shared_argmax[lane_id];
+            }
+#pragma unroll
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+                if (val > maxval) {
+                    maxval = val;
+                    argmax = col;
+                }
+            }
+        }
+    }
+
+    if (warp_id == 0 && lane_id == 0) {
+        dst[row] = argmax;
+    }
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    int32_t     * dst_d  = (int32_t     *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t num_blocks = nrows;
+    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
+    const dim3 blocks_dim(num_threads, 1, 1);
+    const dim3 blocks_num(num_blocks, 1, 1);
+
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
+}
diff --git a/llama/ggml-cuda/argmax.cuh b/llama/ggml-cuda/argmax.cuh
new file mode 100644
index 000000000..805a90d8c
--- /dev/null
+++ b/llama/ggml-cuda/argmax.cuh
@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/argsort.cu b/llama/ggml-cuda/argsort.cu
new file mode 100644
index 000000000..d9aaaa13c
--- /dev/null
+++ b/llama/ggml-cuda/argsort.cu
@@ -0,0 +1,130 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "argsort.cuh"
+
+template <typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
+}
+
+template <ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
+    // bitonic sort
+    int col = threadIdx.x;
+    int row = blockIdx.y;
+
+    if (col >= ncols_pad) {
+        return;
+    }
+
+    const float * x_row = x + row * ncols;
+    extern __shared__ int dst_row[];
+
+    // initialize indices
+    dst_row[col] = col;
+
+    __syncthreads();
+
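+    // bitonic sorting network over the padded width: k is the current bitonic sequence length,
+    // j the compare-exchange distance; indices >= ncols act as padding sentinels that sort last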
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
+
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    // bitonic sort requires ncols to be power of 2
+    const int ncols_pad = next_power_of_2(ncols);
+
+    const dim3 block_dims(ncols_pad, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+    } else if (order == GGML_SORT_ORDER_DESC) {
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+}
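
For orientation, a minimal CPU sketch of what one row of the bitonic argsort kernel above computes (indices ordered by their values, ascending or descending). The file name and helper below are hypothetical and only illustrate the kernel's contract, not the vendored implementation:

    // argsort_reference.cpp (hypothetical illustration)
    #include <algorithm>
    #include <numeric>
    #include <vector>

    std::vector<int> argsort_row(const std::vector<float> & x, bool ascending) {
        std::vector<int> idx(x.size());
        std::iota(idx.begin(), idx.end(), 0);              // indices 0..ncols-1, like dst_row[col] = col
        std::sort(idx.begin(), idx.end(), [&](int a, int b) {
            return ascending ? x[a] < x[b] : x[a] > x[b];  // GGML_SORT_ORDER_ASC / GGML_SORT_ORDER_DESC
        });
        return idx;
    }
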
diff --git a/llama/ggml-cuda/argsort.cuh b/llama/ggml-cuda/argsort.cuh
new file mode 100644
index 000000000..0d8427bb1
--- /dev/null
+++ b/llama/ggml-cuda/argsort.cuh
@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/binbcast.cu b/llama/ggml-cuda/binbcast.cu
new file mode 100644
index 000000000..40b9fcbe1
--- /dev/null
+++ b/llama/ggml-cuda/binbcast.cu
@@ -0,0 +1,381 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "binbcast.cuh"
+#include <cstdint>
+
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
+        /*int s10,*/ int s11, int s12, int s13) {
+    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
+    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
+    const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
+        /*int s10,*/ int s11, int s12, int s13) {
+
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+template <typename T>
+static __global__ void k_repeat_back(
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+
+    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
+    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+
+    if (tid0 >= ne0) {
+        return;
+    }
+
+    T sum = 0;
+    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+            }
+        }
+    }
+    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_cuda {
+    template<typename src0_t, typename src1_t, typename dst_t>
+    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+            cudaStream_t stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne[] = {ne0, ne1, ne2, ne3};
+        int64_t cne0[] = {ne00, ne01, ne02, ne03};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+
+        size_t cnb[] = {nb0, nb1, nb2, nb3};
+        size_t cnb0[] = {nb00, nb01, nb02, nb03};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+            for (int i = 0; i < 4; i++) {
+                if (nr[i] != 1) {
+                    break;
+                }
+                if (i > 0) {
+                    collapse_nb(cnb, cne);
+                    collapse_nb(cnb0, cne0);
+                    collapse_nb(cnb1, cne1);
+                    collapse(cne);
+                    collapse(cne0);
+                    collapse(cne1);
+                }
+            }
+        }
+
+        {
+            int64_t ne0 = cne[0];
+            int64_t ne1 = cne[1];
+            int64_t ne2 = cne[2];
+            int64_t ne3 = cne[3];
+
+            //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+            //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+            //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+            //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb[0];
+            size_t nb1 = cnb[1];
+            size_t nb2 = cnb[2];
+            size_t nb3 = cnb[3];
+
+            size_t nb00 = cnb0[0];
+            size_t nb01 = cnb0[1];
+            size_t nb02 = cnb0[2];
+            size_t nb03 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            size_t s00 = nb00 / sizeof(src0_t);
+            size_t s01 = nb01 / sizeof(src0_t);
+            size_t s02 = nb02 / sizeof(src0_t);
+            size_t s03 = nb03 / sizeof(src0_t);
+
+            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s00 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            dim3 block_dims;
+            block_dims.x = std::min<unsigned int>(hne0, block_size);
+            block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+            block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+            dim3 block_nums(
+                (hne0 + block_dims.x - 1) / block_dims.x,
+                (ne1 + block_dims.y - 1) / block_dims.y,
+                (ne2*ne3 + block_dims.z - 1) / block_dims.z
+            );
+
+            if (block_nums.z > 65535) {
+                // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
+                    /* s10, */ s11, s12, s13);
+            } else {
+                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
+                    src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
+                    /* s10, */ s11, s12, s13);
+            }
+        }
+    }
+};
+
+template <typename T>
+static void repeat_back_cuda(
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+}
+
+template<class op>
+static void ggml_cuda_op_bin_bcast(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    GGML_ASSERT(dst->ne[3] == 1);
+
+    switch (dst->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            float       * dst_d  = (float       *) dst->data;
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        } break;
+        default: {
+            GGML_ASSERT(false);
+        } break;
+    }
+}
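
As a plain-C++ sanity check on the broadcast indexing used by k_bin_bcast above (indices into the smaller operand wrap with a modulo), here is a sketch of the common "add a row vector to every row" case; the file and function names are hypothetical:

    // binbcast_reference.cpp (hypothetical illustration)
    #include <vector>

    // dst[i1][i0] = src0[i1][i0] + src1[i0 % ne10], matching the kernel's i10 = i0 % ne10 wrap
    void add_broadcast_rows(std::vector<float> & dst, const std::vector<float> & src0,
                            const std::vector<float> & src1, int ne0, int ne1, int ne10) {
        for (int i1 = 0; i1 < ne1; ++i1) {
            for (int i0 = 0; i0 < ne0; ++i0) {
                dst[i1*ne0 + i0] = src0[i1*ne0 + i0] + src1[i0 % ne10];
            }
        }
    }
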
diff --git a/llama/ggml-cuda/binbcast.cuh b/llama/ggml-cuda/binbcast.cuh
new file mode 100644
index 000000000..3acee0d07
--- /dev/null
+++ b/llama/ggml-cuda/binbcast.cuh
@@ -0,0 +1,35 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/clamp.cu b/llama/ggml-cuda/clamp.cu
new file mode 100644
index 000000000..2df1076ca
--- /dev/null
+++ b/llama/ggml-cuda/clamp.cu
@@ -0,0 +1,60 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "clamp.cuh"
+
+static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
+
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+    clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
+}
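
The launch above sizes its grid by ceiling division; a small host-only sketch of that arithmetic together with the scalar clamp each thread applies (hypothetical file, illustration only):

    // clamp_reference.cpp (hypothetical illustration)
    #include <cstdio>

    int main() {
        const int k = 1000, block = 256;                 // block = CUDA_CLAMP_BLOCK_SIZE
        const int num_blocks = (k + block - 1) / block;  // 4 blocks cover 1000 elements
        const float x = 3.7f, lo = 0.0f, hi = 1.0f;
        const float y = x < lo ? lo : (x > hi ? hi : x); // same expression as clamp_f32
        std::printf("%d blocks, clamp(%.1f) = %.1f\n", num_blocks, x, y);
        return 0;
    }
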
diff --git a/llama/ggml-cuda/clamp.cuh b/llama/ggml-cuda/clamp.cuh
new file mode 100644
index 000000000..3f74a8807
--- /dev/null
+++ b/llama/ggml-cuda/clamp.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CLAMP_BLOCK_SIZE 256
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/common.cuh b/llama/ggml-cuda/common.cuh
new file mode 100644
index 000000000..2a40b8499
--- /dev/null
+++ b/llama/ggml-cuda/common.cuh
@@ -0,0 +1,725 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-cuda.h"
+
+#include <cstdint>
+#include <memory>
+
+#if defined(GGML_USE_HIP)
+#define GGML_COMMON_DECL_HIP
+#define GGML_COMMON_IMPL_HIP
+#else
+#define GGML_COMMON_DECL_CUDA
+#define GGML_COMMON_IMPL_CUDA
+#if defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
+#endif
+#endif
+#include "ggml-common.h"
+
+#include <cstdio>
+#include <array>
+#include <cassert>
+#include <cfloat>
+#include <string>
+#include <vector>
+
+#if defined(GGML_USE_HIP)
+#include "vendors/hip.h"
+#elif defined(GGML_USE_MUSA)
+#include "vendors/musa.h"
+#else
+#include "vendors/cuda.h"
+#endif // defined(GGML_USE_HIP)
+
+#define STRINGIZE_IMPL(...) #__VA_ARGS__
+#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
+
+#define WARP_SIZE 32
+#define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+#define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
+
+#define GGML_CUDA_CC_PASCAL     600
+#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA      700
+#define GGML_CUDA_CC_TURING     750
+#define GGML_CUDA_CC_AMPERE     800
+#define GGML_CUDA_CC_OFFSET_AMD 1000000
+
+// GCN/CDNA, wave size is 64
+#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
+#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
+#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 910)  // MI210, minimum acc register renaming
+#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 942)  // MI300
+
+// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
+#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
+#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
+
+#define GGML_CUDA_CC_QY1        210
+#define GGML_CUDA_CC_QY2        220
+
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#define GGML_CUDA_MAX_STREAMS 8
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
+
+#define CUDA_CHECK_GEN(err, success, error_fn)                                      \
+     do {                                                                           \
+        auto err_ = (err);                                                          \
+        if (err_ != (success)) {                                                    \
+            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));    \
+        }                                                                           \
+    } while (0)
+
+#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
+
+#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        return cublasGetStatusString(err);
+    }
+#else
+    static const char * cublas_get_error_str(const cublasStatus_t err) {
+        switch (err) {
+            case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+            case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+            case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+            case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+            case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+            case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+            case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+            case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+            case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
+            default: return "unknown error";
+        }
+    }
+#endif // CUDART_VERSION >= 12000
+
+#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
+
+#if !defined(GGML_USE_HIP)
+static const char * cu_get_error_str(CUresult err) {
+    const char * err_str;
+    cuGetErrorString(err, &err_str);
+    return err_str;
+}
+#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
+#endif
+
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
+#ifdef GGML_CUDA_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif // GGML_CUDA_F16
+
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#define FP16_AVAILABLE
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#define FP16_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#define INT8_MMA_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+
+static constexpr bool fast_fp16_available(const int cc) {
+    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
+}
+
+static constexpr bool fp16_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+}
+
+static constexpr bool int8_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
+}
+
+[[noreturn]]
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    GGML_UNUSED(arch_list);
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+    __trap();
+
+    GGML_UNUSED(no_device_code); // suppress unused function warning
+}
+
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+    return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    }
+    return x;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+}
+
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
+        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
+        reinterpret_cast<half&>(a.y) += __high2half(a_other);
+    }
+    return a;
+#else
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
+    }
+    return a;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    }
+    return x;
+}
+
+static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#ifdef FP16_AVAILABLE
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+    return __float2half(fmaxf(__half2float(a), __half2float(b)));
+#else
+    return __hmax(a, b);
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+
+#else
+   NO_DEVICE_CODE;
+   GGML_UNUSED(b);
+   return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+
+#if CUDART_VERSION >= CUDART_HMAX
+    return __hmax2(a, b);
+#else
+    half2 ret;
+    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
+    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
+    return ret;
+#endif // CUDART_VERSION >= CUDART_HMAX
+
+#else
+    GGML_UNUSED(a);
+    GGML_UNUSED(b);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+}
+
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#pragma unroll
+   for (int offset = 16; offset > 0; offset >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+   }
+   return x;
+#else
+   GGML_UNUSED(x);
+   NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+}
+
+#if CUDART_VERSION < CUDART_HMASK
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+    return mask_low | mask_high;
+}
+#endif // CUDART_VERSION < CUDART_HMASK
+
+static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+    c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3)
+    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(__gfx1010__) || defined(__gfx900__)
+    int tmp1;
+    int tmp2;
+    asm("\n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+        v_add3_u32 %0, %1, %2, %0 \n \
+        "
+        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+        : "v"(a), "v"(b)
+    );
+#else
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+    return c;
+
+#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+    return __dp4a(a, b, c);
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+    const int8_t * a8 = (const int8_t *) &a;
+    const int8_t * b8 = (const int8_t *) &b;
+    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+}
+
+// TODO: move to ggml-common.h
+static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
+
+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
+
+template <ggml_type type>
+struct ggml_cuda_type_traits;
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_F16> {
+    static constexpr int qk = 1;
+    static constexpr int qr = 1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+    static constexpr int qk = QK4_0;
+    static constexpr int qr = QR4_0;
+    static constexpr int qi = QI4_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
+    static constexpr int qk = QK4_1;
+    static constexpr int qr = QR4_1;
+    static constexpr int qi = QI4_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
+    static constexpr int qk = QK5_0;
+    static constexpr int qr = QR5_0;
+    static constexpr int qi = QI5_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
+    static constexpr int qk = QK5_1;
+    static constexpr int qr = QR5_1;
+    static constexpr int qi = QI5_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
+    static constexpr int qk = QK8_0;
+    static constexpr int qr = QR8_0;
+    static constexpr int qi = QI8_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_K;
+    static constexpr int qi = QI2_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_K;
+    static constexpr int qi = QI3_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_K;
+    static constexpr int qi = QI4_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR5_K;
+    static constexpr int qi = QI5_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR6_K;
+    static constexpr int qi = QI6_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XXS;
+    static constexpr int qi = QI2_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_XS;
+    static constexpr int qi = QI2_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR2_S;
+    static constexpr int qi = QI2_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_XXS;
+    static constexpr int qi = QI3_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_S;
+    static constexpr int qi = QI1_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR1_M;
+    static constexpr int qi = QI1_M;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
+    static constexpr int qk = QK4_NL;
+    static constexpr int qr = QR4_NL;
+    static constexpr int qi = QI4_NL;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR4_XS;
+    static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
+    static constexpr int qk = QK_K;
+    static constexpr int qr = QR3_S;
+    static constexpr int qi = QI3_S;
+};
+
+//////////////////////
+
+struct ggml_cuda_device_info {
+    int device_count;
+
+    struct cuda_device_info {
+        int     cc;                 // compute capability
+        int     nsm;                // number of streaming multiprocessors
+        size_t  smpb;               // max. shared memory per block
+        size_t  smpbo;              // max. shared memory per block (with opt-in)
+        bool    vmm;                // virtual memory support
+        size_t  vmm_granularity;    // granularity of virtual memory
+        size_t  total_vram;
+    };
+
+    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
+
+    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
+};
+
+const ggml_cuda_device_info & ggml_cuda_info();
+
+void ggml_cuda_set_device(int device);
+int ggml_cuda_get_device();
+
+struct ggml_cuda_pool {
+    virtual ~ggml_cuda_pool() = default;
+
+    virtual void * alloc(size_t size, size_t * actual_size) = 0;
+    virtual void free(void * ptr, size_t size) = 0;
+};
+
+template<typename T>
+struct ggml_cuda_pool_alloc {
+    ggml_cuda_pool * pool = nullptr;
+    T * ptr = nullptr;
+    size_t actual_size = 0;
+
+    ggml_cuda_pool_alloc() = default;
+
+    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
+    }
+
+    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
+        alloc(size);
+    }
+
+    ~ggml_cuda_pool_alloc() {
+        if (ptr != nullptr) {
+            pool->free(ptr, actual_size);
+        }
+    }
+
+    // size is in number of elements
+    T * alloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        GGML_ASSERT(ptr == nullptr);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
+    T * alloc(ggml_cuda_pool & pool, size_t size) {
+        this->pool = &pool;
+        return alloc(size);
+    }
+
+    T * get() {
+        return ptr;
+    }
+
+    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
+    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
+    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
+    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
+};
+
+
+// backend interface
+
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
+};
+
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+    void * node_address;
+    ggml_op node_op;
+    int64_t ne[GGML_MAX_DIMS];
+    size_t nb[GGML_MAX_DIMS];
+    void * src_address[GGML_MAX_SRC];
+    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+    ~ggml_cuda_graph() {
+        if (instance != nullptr) {
+            CUDA_CHECK(cudaGraphExecDestroy(instance));
+        }
+        if (graph != nullptr) {
+            CUDA_CHECK(cudaGraphDestroy(graph));
+        }
+    }
+    cudaGraph_t graph = nullptr;
+    cudaGraphExec_t instance = nullptr;
+    size_t num_nodes = 0;
+    std::vector<cudaGraphNode_t> nodes;
+    std::vector<cudaKernelNodeParams> params;
+    bool disable_due_to_gpu_arch = false;
+    bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
+    int number_consecutive_updates = 0;
+    std::vector<ggml_graph_node_properties> ggml_graph_properties;
+    std::vector<char **> updated_kernel_arg;
+#endif
+};
+
+struct ggml_backend_cuda_context {
+    int device;
+    std::string name;
+    cudaEvent_t copy_event = nullptr;
+
+    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
+    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+    std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
+    explicit ggml_backend_cuda_context(int device) :
+        device(device),
+        name(GGML_CUDA_NAME + std::to_string(device)) {
+    }
+
+    ~ggml_backend_cuda_context() {
+        if (copy_event != nullptr) {
+            CUDA_CHECK(cudaEventDestroy(copy_event));
+        }
+        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+                if (streams[i][j] != nullptr) {
+                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+                }
+            }
+            if (cublas_handles[i] != nullptr) {
+                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+            }
+        }
+    }
+
+    cudaStream_t stream(int device, int stream) {
+        if (streams[device][stream] == nullptr) {
+            ggml_cuda_set_device(device);
+            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
+        }
+        return streams[device][stream];
+    }
+
+    cudaStream_t stream() {
+        return stream(device, 0);
+    }
+
+    cublasHandle_t cublas_handle(int device) {
+        if (cublas_handles[device] == nullptr) {
+            ggml_cuda_set_device(device);
+            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
+            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
+        }
+        return cublas_handles[device];
+    }
+
+    cublasHandle_t cublas_handle() {
+        return cublas_handle(device);
+    }
+
+    // pool
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+
+    ggml_cuda_pool & pool(int device) {
+        if (pools[device] == nullptr) {
+            pools[device] = new_pool_for_device(device);
+        }
+        return *pools[device];
+    }
+
+    ggml_cuda_pool & pool() {
+        return pool(device);
+    }
+};
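
A host-side sketch of the XOR-shuffle ("butterfly") pattern behind warp_reduce_sum above: after log2(32) exchange steps every simulated lane holds the full warp sum. The file name is hypothetical and the inner loop merely stands in for __shfl_xor_sync:

    // warp_reduce_demo.cpp (hypothetical illustration)
    #include <array>
    #include <cstdio>

    int main() {
        std::array<float, 32> lane{};
        for (int i = 0; i < 32; ++i) lane[i] = float(i);        // one value per simulated lane

        for (int offset = 16; offset > 0; offset >>= 1) {
            std::array<float, 32> next = lane;
            for (int i = 0; i < 32; ++i) {
                next[i] = lane[i] + lane[i ^ offset];           // partner lane, as in __shfl_xor_sync
            }
            lane = next;
        }
        std::printf("lane 0 holds %.0f (0+1+...+31 = 496)\n", lane[0]);
        return 0;
    }
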
diff --git a/llama/ggml-cuda/concat.cu b/llama/ggml-cuda/concat.cu
new file mode 100644
index 000000000..d8c473913
--- /dev/null
+++ b/llama/ggml-cuda/concat.cu
@@ -0,0 +1,247 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "concat.cuh"
+
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (nidx < ne00) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            (nidx - ne00) +
+            blockIdx.y * (ne0 - ne00) +
+            blockIdx.z * (ne0 - ne00) * gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.y < ne01) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            (blockIdx.y - ne01) * ne0 +
+            blockIdx.z * ne0 * (gridDim.y - ne01);
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+
+    if (blockIdx.z < ne02) { // src0
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            blockIdx.z * ne0 * gridDim.y;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne0 +
+            (blockIdx.z - ne02) * ne0 *  gridDim.y;
+        dst[offset_dst] = y[offset_src];
+    }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2);
+    if (dim == 0) {
+        concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+        return;
+    }
+    if (dim == 1) {
+        concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+        return;
+    }
+    concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+// non-contiguous kernel (slow)
+template <int dim>
+static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
+    concat_f32_non_cont(
+        const char * src0,
+        const char * src1,
+              char * dst,
+           int64_t   ne00,
+           int64_t   ne01,
+           int64_t   ne02,
+           int64_t   ne03,
+          uint64_t   nb00,
+          uint64_t   nb01,
+          uint64_t   nb02,
+          uint64_t   nb03,
+           int64_t /*ne10*/,
+           int64_t /*ne11*/,
+           int64_t /*ne12*/,
+           int64_t /*ne13*/,
+          uint64_t   nb10,
+          uint64_t   nb11,
+          uint64_t   nb12,
+          uint64_t   nb13,
+           int64_t   ne0,
+           int64_t /*ne1*/,
+           int64_t /*ne2*/,
+           int64_t /*ne3*/,
+          uint64_t   nb0,
+          uint64_t   nb1,
+          uint64_t   nb2,
+          uint64_t   nb3){
+    static_assert(dim >= 0 && dim <= 3, "dim must be between 0 and 3");
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    const float * x;
+
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            if constexpr (dim == 0) {
+                x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
+            } else if constexpr (dim == 1) {
+                x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
+            } else if constexpr (dim == 2) {
+                x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
+            } else if constexpr (dim == 3) {
+                x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
+            }
+        }
+
+        float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
+    }
+}
+
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        const float * src0_d = (const float *)src0->data;
+        const float * src1_d = (const float *)src1->data;
+
+        float * dst_d = (float *)dst->data;
+
+        if (dim != 3) {
+            for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+                concat_f32_cuda(
+                        src0_d + i3 * (src0->nb[3] / 4),
+                        src1_d + i3 * (src1->nb[3] / 4),
+                        dst_d + i3 * ( dst->nb[3] / 4),
+                        src0->ne[0], src0->ne[1], src0->ne[2],
+                        dst->ne[0],  dst->ne[1],  dst->ne[2], dim, stream);
+            }
+        } else {
+            const size_t size0 = ggml_nbytes(src0);
+            const size_t size1 = ggml_nbytes(src1);
+
+            CUDA_CHECK(cudaMemcpyAsync(dst_d,           src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+            CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+        }
+    } else {
+        dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+        auto launch_kernel = [&](auto dim) {
+            concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+                (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
+                src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+                dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+        };
+        switch (dim) {
+            case 0:
+                launch_kernel(std::integral_constant<int, 0>{});
+                break;
+            case 1:
+                launch_kernel(std::integral_constant<int, 1>{});
+                break;
+            case 2:
+                launch_kernel(std::integral_constant<int, 2>{});
+                break;
+            case 3:
+                launch_kernel(std::integral_constant<int, 3>{});
+                break;
+            default:
+                GGML_ABORT("Invalid dim: %d", dim);
+                break;
+        }
+    }
+}
diff --git a/llama/ggml-cuda/concat.cuh b/llama/ggml-cuda/concat.cuh
new file mode 100644
index 000000000..ba2b67ec1
--- /dev/null
+++ b/llama/ggml-cuda/concat.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CONCAT_BLOCK_SIZE 256
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/conv-transpose-1d.cu b/llama/ggml-cuda/conv-transpose-1d.cu
new file mode 100644
index 000000000..da53e9469
--- /dev/null
+++ b/llama/ggml-cuda/conv-transpose-1d.cu
@@ -0,0 +1,113 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "conv-transpose-1d.cuh"
+
+static  __global__ void conv_transpose_1d_kernel(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst) {
+    int global_index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (global_index >= output_size) {
+        return;
+    }
+
+    int out_index = global_index / dst_ne0;
+
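+    // each thread owns one output element: accumulate over every input channel and over
+    // the kernel taps whose stride-s0 placement covers this output position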
+    float accumulator = 0;
+
+    for (int c = 0; c < src0_ne2; c++) {
+        int idx = global_index % dst_ne0;
+
+        int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+        int input_offset = src1_ne0 * c;
+
+        for (int i = 0; i < src1_ne0; i++) {
+            if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+                continue;
+            }
+            int weight_idx = idx - i*s0;
+
+            float kernel_weight = src0[kernel_offset + weight_idx];
+            float input_value =  src1[input_offset+i];
+
+            accumulator += kernel_weight * input_value;
+        }
+    }
+    dst[global_index] = accumulator;
+}
+
+static void conv_transpose_1d_f32_f32_cuda(
+        const int s0, const int p0, const int d0, const int output_size,
+        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+        const float * src0, const float * src1,  float * dst,
+        cudaStream_t stream) {
+
+    const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
+    conv_transpose_1d_kernel<<<num_blocks, CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
+        s0,p0,d0,output_size,
+        src0_ne0, src0_ne1,  src0_ne2, src0_ne3,
+        src1_ne0, src1_ne1,  src1_ne2, src1_ne3,
+        dst_ne0,  dst_ne1,   dst_ne2,  dst_ne3,
+        src0,src1, dst);
+}
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+
+    const int s0 = opts[0];
+    const int p0 = 0;//opts[3];
+    const int d0 = 1;//opts[4];
+
+    const int64_t kernel_size = ggml_nelements(src0);
+    const int64_t input_size = ggml_nelements(src1);
+    const int64_t output_size = ggml_nelements(dst);
+
+    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        src0_d, src1_d, dst_d, stream);
+}
diff --git a/llama/ggml-cuda/conv-transpose-1d.cuh b/llama/ggml-cuda/conv-transpose-1d.cuh
new file mode 100644
index 000000000..53c3beefd
--- /dev/null
+++ b/llama/ggml-cuda/conv-transpose-1d.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/convert.cu b/llama/ggml-cuda/convert.cu
new file mode 100644
index 000000000..6ddb87fc3
--- /dev/null
+++ b/llama/ggml-cuda/convert.cu
@@ -0,0 +1,714 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "convert.cuh"
+#include "dequantize.cuh"
+
+#define CUDA_Q8_0_NE_ALIGN 2048
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
+
+    if (i >= k) {
+        return;
+    }
+
+    const int64_t ib = i/qk; // block index
+    const int64_t iqs = (i%qk)/qr; // quant index
+    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(vx, ib, iqs, v);
+
+    y[iybs + iqs + 0]        = v.x;
+    y[iybs + iqs + y_offset] = v.y;
+}
+
+template <bool need_check>
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
+
+    const int64_t   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
+    const int * x0 = ((int *) vx) + blockIdx.x * nint;
+    half2 * y2 = (half2 *) (y + i0);
+
+    __shared__ int vals[nint];
+
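+    // stage the raw q8_0 block data for this CUDA_Q8_0_NE_ALIGN-element slice into shared
+    // memory as ints, then unpack it below as half2 pairs scaled by each block's delta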
+#pragma unroll
+    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
+        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
+            break;
+        }
+
+        const int ix = ix0 + threadIdx.x;
+        vals[ix] = x0[ix];
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
+        if (need_check && i0 + iy + 2*threadIdx.x >= k) {
+            return;
+        }
+
+        const half * b0 = ((const half  *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
+        const half    d = *b0;
+        const char2  qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
+
+        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
+    }
+#else
+    GGML_UNUSED(vx);
+    GGML_UNUSED(y);
+    GGML_UNUSED(k);
+    NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+    const float d = __half2float(x->d);
+    const float dm = -8*d;
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d * (q[l] & 0xF) + dm;
+        y[l+16] = d * (q[l] >>  4) + dm;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+    const float2 d = __half22float2(x->dm);
+
+    const uint8_t * q = x->qs + 4*il;
+
+    for (int l = 0; l < 4; ++l) {
+        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+        y[l+16] = d.x * (q[l] >>  4) + d.y;
+    }
+}
+
+//================================== k-quants
+
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t n   = tid/32;
+    const int64_t l   = tid - 32*n;
+    const int64_t is  = 8*n + l/16;
+
+    const uint8_t q = x[i].qs[32*n + l];
+    dst_t * y = yy + i*QK_K + 128*n;
+
+    float dall = __low2half(x[i].dm);
+    float dmin = __high2half(x[i].dm);
+    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i = blockIdx.x;
+    const block_q3_K * x = (const block_q3_K *) vx;
+
+    const int64_t r = threadIdx.x/4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;
+
+    uint8_t m = 1 << (4*n + j);
+    int64_t is = 8*n + 2*j + is0;
+    int shift = 2*j;
+
+    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+    float d_all = x[i].d;
+    float dl = d_all * (us - 32);
+
+    dst_t * y = yy + i*QK_K + 128*n + 32*j;
+    const uint8_t * q = x[i].qs + 32*n;
+    const uint8_t * hm = x[i].hmask;
+
+    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+}
+
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+    if (j < 4) {
+        d = q[j] & 63; m = q[j + 4] & 63;
+    } else {
+        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q4_K * x = (const block_q4_K *) vx;
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t is  = 2*il;
+    const int64_t n   = 4;
+
+    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
+
+    const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+    for (int l = 0; l < n; ++l) {
+        y[l + 0] = d1 * (q[l] & 0xF) - m1;
+        y[l +32] = d2 * (q[l] >>  4) - m2;
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q5_K * x = (const block_q5_K *) vx;
+
+    const int64_t i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/16;   // il is in 0...3
+    const int64_t ir  = tid%16;   // ir is in 0...15
+    const int64_t is  = 2*il;     // is is in 0...6
+
+    dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+    const float dall = __low2half(x[i].dm);
+    const float dmin = __high2half(x[i].dm);
+
+    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+    const uint8_t * qh = x[i].qh + 2*ir;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, x[i].scales, sc, m);
+    const float d1 = dall * sc; const float m1 = dmin * m;
+    get_scale_min_k4(is + 1, x[i].scales, sc, m);
+    const float d2 = dall * sc; const float m2 = dmin * m;
+
+    uint8_t   hm  = 1 << (2*il);
+    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+    hm <<= 1;
+    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const block_q6_K * x = (const block_q6_K *) vx;
+
+    const int64_t i = blockIdx.x;
+
+    // assume 64 threads - this is very slightly better than the one below
+    const int64_t tid = threadIdx.x;
+    const int64_t ip  = tid/32;   // ip is 0 or 1
+    const int64_t il  = tid - 32*ip; // 0...32
+    const int64_t is  = 8*ip + il/16;
+
+    dst_t * y = yy + i*QK_K + 128*ip + il;
+
+    const float d = x[i].d;
+
+    const uint8_t * ql = x[i].ql + 64*ip + il;
+    const uint8_t   qh = x[i].qh[32*ip + il];
+    const int8_t  * sc = x[i].scales + is;
+
+    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+    y[64] = d * sc[4] * ((int8_t)((ql[ 0]  >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+    y[96] = d * sc[6] * ((int8_t)((ql[32]  >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq2_xxs * x = (const block_iq2_xxs  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * aux8 = (const uint8_t *)q2;
+    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * q2 = x[i].qs + 4*ib;
+    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq2_s * x = (const block_iq2_s *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq3_xxs * x = (const block_iq3_xxs  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t  * q3 = x[i].qs + 8*ib;
+    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq3_s * x = (const block_iq3_s *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint8_t * qs = x[i].qs + 8*ib;
+    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+    const uint8_t signs = x[i].signs[4*ib + il];
+    for (int j = 0; j < 4; ++j) {
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq1_s * x = (const block_iq1_s  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq1_m * x = (const block_iq1_m  *) vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+    const uint16_t * sc = (const uint16_t *)x[i].scales;
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+    grid32[0] &= 0x0f0f0f0f;
+    for (int j = 0; j < 8; ++j) {
+        y[j] = d * (q[j] + delta);
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = (float)x[ib].d;
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+    const int64_t i   = blockIdx.x;
+    const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[i].qs + 16*ib + 4*il;
+    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
+    }
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
+    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
+    if (k % CUDA_Q8_0_NE_ALIGN == 0) {
+        const bool need_check = false;
+        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+    } else {
+        const bool need_check = true;
+        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+    }
+}
+
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = k / QK_K;
+    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template <typename src_t, typename dst_t>
+static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const src_t * x = (src_t *) vx;
+
+    y[i] = x[i];
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+}
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
+                return dequantize_block_q8_0_f16_cuda;
+            }
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+            return dequantize_row_q4_0_cuda;
+        case GGML_TYPE_Q4_1:
+            return dequantize_row_q4_1_cuda;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+        case GGML_TYPE_Q2_K:
+            return dequantize_row_q2_K_cuda;
+        case GGML_TYPE_Q3_K:
+            return dequantize_row_q3_K_cuda;
+        case GGML_TYPE_Q4_K:
+            return dequantize_row_q4_K_cuda;
+        case GGML_TYPE_Q5_K:
+            return dequantize_row_q5_K_cuda;
+        case GGML_TYPE_Q6_K:
+            return dequantize_row_q6_K_cuda;
+        case GGML_TYPE_IQ2_XXS:
+            return dequantize_row_iq2_xxs_cuda;
+        case GGML_TYPE_IQ2_XS:
+            return dequantize_row_iq2_xs_cuda;
+        case GGML_TYPE_IQ2_S:
+            return dequantize_row_iq2_s_cuda;
+        case GGML_TYPE_IQ3_XXS:
+            return dequantize_row_iq3_xxs_cuda;
+        case GGML_TYPE_IQ1_S:
+            return dequantize_row_iq1_s_cuda;
+        case GGML_TYPE_IQ1_M:
+            return dequantize_row_iq1_m_cuda;
+        case GGML_TYPE_IQ4_NL:
+            return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_IQ4_XS:
+            return dequantize_row_iq4_xs_cuda;
+        case GGML_TYPE_IQ3_S:
+            return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_F16:
+            return convert_unary_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
diff --git a/llama/ggml-cuda/convert.cuh b/llama/ggml-cuda/convert.cuh
new file mode 100644
index 000000000..27f949e2a
--- /dev/null
+++ b/llama/ggml-cuda/convert.cuh
@@ -0,0 +1,39 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
+template<typename T>
+using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
+
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
diff --git a/llama/ggml-cuda/count-equal.cu b/llama/ggml-cuda/count-equal.cu
new file mode 100644
index 000000000..e4496fc1b
--- /dev/null
+++ b/llama/ggml-cuda/count-equal.cu
@@ -0,0 +1,90 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "count-equal.cuh"
+
+#include <cstdint>
+
+template <typename T>
+static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
+    const int64_t i0 = (int64_t) blockIdx.x*dk;
+    const int64_t i1 = min(i0 + dk, k);
+
+    int nequal = 0;
+
+    for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
+        const T xi = x[i];
+        const T yi = y[i];
+        nequal += xi == yi;
+    }
+
+    nequal = warp_reduce_sum(nequal);
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    atomicAdd((int *) dst, nequal);
+}
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    int64_t * dst_d  = (int64_t *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
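+    // launch at most 4*nsm blocks; each block reduces its chunk of dne elements with a
+    // warp-level sum before contributing a single atomicAdd to the result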
+    const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
+
+    CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_I32: {
+            const int * src0_d = (const int *) src0->data;
+            const int * src1_d = (const int *) src1->data;
+            count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
+        } break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
diff --git a/llama/ggml-cuda/count-equal.cuh b/llama/ggml-cuda/count-equal.cuh
new file mode 100644
index 000000000..922c6288c
--- /dev/null
+++ b/llama/ggml-cuda/count-equal.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/cpy.cu b/llama/ggml-cuda/cpy.cu
new file mode 100644
index 000000000..ffdef8c43
--- /dev/null
+++ b/llama/ggml-cuda/cpy.cu
@@ -0,0 +1,569 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "cpy.cuh"
+
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+
+static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = __float2half(*xi);
+}
+
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
+static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                   const int nb12, const int nb13) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+    cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = fmaxf(amax, fabsf(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
+
+        dsti->qs[j] = roundf(x0);
+    }
+}
+
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+       dsti[j] = xi->qs[j] * d;
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
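+    // the +8.5f below shifts the scaled value into the unsigned nibble range so that the
+    // integer truncation rounds to nearest before clamping to 15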
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
+
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
+
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
+
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = vmin;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+
+        const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
+
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_0 * dsti = (block_q5_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK5_0; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q5_1 * dsti = (block_q5_1 *) cdsti;
+
+    float min = xi[0];
+    float max = xi[0];
+
+    for (int j = 1; j < QK5_1; ++j) {
+        const float v = xi[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d  = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->dm.x = d;
+    dsti->dm.y = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (xi[0       + j] - min)*id;
+        const float x1 = (xi[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dsti->qs[j]  = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+    memcpy(dsti->qh, &qh, sizeof(qh));
+}
+
+
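+// binary search for the entry of the sorted value table (here kvalues_iq4nl) closest to x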
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK4_NL; ++j) {
+        const float v = xi[j];
+        if (amax < fabsf(v)) {
+            amax = fabsf(v);
+            vmax = v;
+        }
+    }
+
+    float d = vmax / kvalues_iq4nl[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
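+    // refine the scale with a weighted least-squares fit of the chosen table values to the
+    // original floats, weighting each pair by the squared inputs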
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = xi[0        + j]*id;
+        const float x1 = xi[QK4_NL/2 + j]*id;
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        dsti->qs[j] = xi0 | (xi1 << 4);
+        const float v0 = kvalues_iq4nl[xi0];
+        const float v1 = kvalues_iq4nl[xi1];
+        const float w0 = xi[0        + j]*xi[0        + j];
+        const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
+        sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+    }
+
+    dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
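+    // each thread quantizes qk consecutive source floats into one destination block,
+    // hence the i10/qk factor in the destination offset below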
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+static void ggml_cpy_f16_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q8_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_0 == 0);
+    const int num_blocks = ne / QK5_0;
+    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_1_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK5_1 == 0);
+    const int num_blocks = ne / QK5_1;
+    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_iq4_nl_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK4_NL == 0);
+    const int num_blocks = ne / QK4_NL;
+    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
+
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    //GGML_ASSERT(src0->ne[3] == 1);
+
+    const int64_t nb00 = src0->nb[0];
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
+    const int64_t nb03 = src0->nb[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+
+    //GGML_ASSERT(src1->ne[3] == 1);
+
+    const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
+
+    cudaStream_t main_stream = ctx.stream();
+
+    char * src0_ddc = (char *) src0->data;
+    char * src1_ddc = (char *) src1->data;
+
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+        CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else {
+        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+    }
+}
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    ggml_cuda_cpy(ctx, src0, dst);
+}
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        return nullptr;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else {
+        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+    }
+}
diff --git a/llama/ggml-cuda/cpy.cuh b/llama/ggml-cuda/cpy.cuh
new file mode 100644
index 000000000..79496c4cc
--- /dev/null
+++ b/llama/ggml-cuda/cpy.cuh
@@ -0,0 +1,35 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CPY_BLOCK_SIZE 64
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
diff --git a/llama/ggml-cuda/cross-entropy-loss.cu b/llama/ggml-cuda/cross-entropy-loss.cu
new file mode 100644
index 000000000..5bfddc79b
--- /dev/null
+++ b/llama/ggml-cuda/cross-entropy-loss.cu
@@ -0,0 +1,192 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "cross-entropy-loss.cuh"
+#include "sum.cuh"
+
+#include <cmath>
+#include <cstdint>
+
+static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
+
+    const int ne_tmp = WARP_SIZE*nclasses;
+
+    extern __shared__ float tmp_all[];
+    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
+    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
+
+    // Each warp first loads ne_tmp logits/labels into shared memory:
+    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
+        const int ig = i0*nclasses + i; // ig == i global
+
+        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
+        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
+    }
+
+    // Each thread in the warp then calculates the cross entropy loss for a single row.
+    // TODO: pad in order to avoid shared memory bank conflicts.
+
+    // Find maximum for softmax:
+    float max = -INFINITY;
+    for (int i = 0; i < nclasses; ++i) {
+        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    }
+
+    // Calculate log(softmax(logits)) which is just logits - max:
+    float sum = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        float val = tmp_logits[lane_id*nclasses + i] - max;
+        sum += expf(val);
+        tmp_logits[lane_id*nclasses + i] = val;
+    }
+    sum = logf(sum);
+
+    // log(exp(logits - max) / sum) = (logits - max) - log(sum)
+    float loss = 0.0f;
+    for (int i = 0; i < nclasses; ++i) {
+        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    }
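+    // Reduce the per-row losses across the warp and negate; dividing by k yields this warp's contribution to the mean loss.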
+    loss = -warp_reduce_sum(loss) / (float)k;
+
+    __syncthreads();
+
+    if (lane_id == 0) {
+        tmp_all[warp_id] = loss;
+    }
+
+    __syncthreads();
+
+    if (warp_id != 0) {
+        return;
+    }
+
+    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
+    loss = warp_reduce_sum(loss);
+
+    if (lane_id != 0) {
+        return;
+    }
+
+    dst[blockIdx.x] = loss;
+}
+
+static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+    extern __shared__ float tmp[];
+
+    float maxval = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[blockIdx.x*nclasses + i];
+        maxval = fmaxf(maxval, val);
+        tmp[i] = val;
+    }
+    maxval = warp_reduce_max(maxval);
+
+    float sum = 0.0f;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = expf(tmp[i] - maxval);
+        sum += val;
+        tmp[i] = val;
+    }
+    sum = warp_reduce_sum(sum);
+    const float sm_scale = 1.0f/sum;
+
+    const float d_by_nrows = *loss/gridDim.x;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+    }
+}
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
+    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
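+    // Shared memory: one row of logits plus one row of labels (ne00 floats each) per thread in the block.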
+    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+
+    ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
+
+    cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+
+    // Combine results from individual blocks:
+    sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
+}
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * opt0 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    const float * opt0_d = (const float *) opt0->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const int shmem = ne00*sizeof(float);
+
+    cross_entropy_loss_back_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+}
diff --git a/llama/ggml-cuda/cross-entropy-loss.cuh b/llama/ggml-cuda/cross-entropy-loss.cuh
new file mode 100644
index 000000000..e816b8df6
--- /dev/null
+++ b/llama/ggml-cuda/cross-entropy-loss.cuh
@@ -0,0 +1,33 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/dequantize.cuh b/llama/ggml-cuda/dequantize.cuh
new file mode 100644
index 000000000..016de0db6
--- /dev/null
+++ b/llama/ggml-cuda/dequantize.cuh
@@ -0,0 +1,129 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {8.0f, 8.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 8.0f) * d;
+    v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    const int vui = x[ib].qs[iqs];
+
+    v.x = vui & 0xF;
+    v.y = vui >> 4;
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+    v = __hsub2(v, {16.0f, 16.0f});
+    v = __hmul2(v, {d, d});
+#else
+    v.x = (v.x - 16.0f) * d;
+    v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const dfloat d = __low2half(x[ib].dm);
+    const dfloat m = __high2half(x[ib].dm);
+
+    uint32_t qh;
+    memcpy(&qh, x[ib].qh, sizeof(qh));
+
+    const int xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10;
+    const int xh_1 = ((qh >> (iqs + 12))     ) & 0x10;
+
+    v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+    v.y = ((x[ib].qs[iqs] >>  4) | xh_1);
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+    v = __hadd2(v, {m, m});
+#else
+    v.x = (v.x * d) + m;
+    v.y = (v.y * d) + m;
+#endif // GGML_CUDA_F16
+}
+
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const dfloat d = x[ib].d;
+
+    v.x = x[ib].qs[iqs + 0];
+    v.y = x[ib].qs[iqs + 1];
+
+#ifdef GGML_CUDA_F16
+    v = __hmul2(v, {d, d});
+#else
+    v.x *= d;
+    v.y *= d;
+#endif // GGML_CUDA_F16
+}
diff --git a/llama/ggml-cuda/diagmask.cu b/llama/ggml-cuda/diagmask.cu
new file mode 100644
index 000000000..e80a953ae
--- /dev/null
+++ b/llama/ggml-cuda/diagmask.cu
@@ -0,0 +1,66 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "diagmask.cuh"
+
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int i = row*ncols + col;
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
+    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+    const dim3 block_nums(nrows_x, block_num_x, 1);
+    diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int nrows0 = ggml_nrows(src0);
+
+    const int n_past = ((int32_t *) dst->op_params)[0];
+
+    diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
+}
diff --git a/llama/ggml-cuda/diagmask.cuh b/llama/ggml-cuda/diagmask.cuh
new file mode 100644
index 000000000..76162837f
--- /dev/null
+++ b/llama/ggml-cuda/diagmask.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/fattn-common.cuh b/llama/ggml-cuda/fattn-common.cuh
new file mode 100644
index 000000000..011654d31
--- /dev/null
+++ b/llama/ggml-cuda/fattn-common.cuh
@@ -0,0 +1,734 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "common.cuh"
+#include "convert.cuh"
+#include "vecdotq.cuh"
+
+#include <cstdint>
+
+#define FATTN_KQ_STRIDE       256
+#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+typedef void (* fattn_kernel_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
+typedef half (*vec_dot_KQ_f16_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+typedef float (*vec_dot_KQ_f32_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_0;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
+            sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+            sum += (T) (sumid4d8 + m4s8scaled);
+        }
+    }
+
+    return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_0;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+        v |= (vh <<  4) & 0x00000010; // 0 ->  4
+        v |= (vh << 11) & 0x00001000; // 1 -> 12
+        v |= (vh << 18) & 0x00100000; // 2 -> 20
+        v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+        }
+    }
+
+    return sum;
+}
+
+template<typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_1;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+        v |= (vh <<  4) & 0x00000010; // 0 ->  4
+        v |= (vh << 11) & 0x00001000; // 1 -> 12
+        v |= (vh << 18) & 0x00100000; // 2 -> 20
+        v |= (vh << 25) & 0x10000000; // 3 -> 28
+
+        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+
+        const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+#ifdef FP16_AVAILABLE
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
+            sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
+        } else
+#endif // FP16_AVAILABLE
+        {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+
+            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+
+            sum += (T) (sumid5d8 + m5s8scaled);
+        }
+    }
+
+    return sum;
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    T sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const int ib  = k_KQ / QI8_0;
+        const int iqs = k_KQ % QI8_0;
+
+        const int v = get_int_b2(K_q8_0[ib].qs, iqs);
+
+        T Q_d;
+        if (std::is_same<T, half>::value) {
+            const half2  * Q_ds = (const half2  *) Q_ds_v;
+            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+        } else {
+            const float2 * Q_ds = (const float2 *) Q_ds_v;
+            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+        }
+
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+    }
+
+    return sum;
+}
+
+template <typename T, int D>
+static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+    const half2 * K_h2 = (const half2 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        const half2 * Q_h2 = (const half2 *) Q_v;
+
+        half2 sum2 = make_half2(0.0f, 0.0f);
+
+#pragma unroll
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+            const int k_KQ = k_KQ_0 + threadIdx.x;
+
+            const half2 K_ik = K_h2[k_KQ];
+            sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+        }
+
+        return __low2half(sum2) + __high2half(sum2);
+    }
+#endif // FP16_AVAILABLE
+
+    const float2 * Q_f2 = (const float2 *) Q_v;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        const int k_KQ = k_KQ_0 + threadIdx.x;
+
+        const half2 K_ik = K_h2[k_KQ];
+        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+    }
+
+    return sum;
+}
+
+template <typename Tds>
+static __device__ __forceinline__ void quantize_q8_1_to_shared(
+    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
+
+    float vals[sizeof(int)] = {0.0f};
+#pragma unroll
+    for (int l = 0; l < sizeof(int); ++l) {
+        vals[l] = scale * x[4*threadIdx.x + l];
+    }
+
+    float amax = fabsf(vals[0]);
+    float sum  = vals[0];
+#pragma unroll
+    for (int l = 1; l < sizeof(int); ++l) {
+        amax = fmaxf(amax, fabsf(vals[l]));
+        sum += vals[l];
+    }
+#pragma unroll
+    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
+        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);
+    }
+
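+    // q8_1 scale: map the warp-wide absolute maximum onto the int8 range [-127, 127].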
+    const float d = amax / 127;
+    int q32 = 0;
+    int8_t * q8 = (int8_t *) &q32;
+
+    if (d != 0.0f) {
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            q8[l] = roundf(vals[l] / d);
+        }
+    }
+
+    yq32[threadIdx.x] = q32;
+    if (threadIdx.x % QI8_1 == 0) {
+        if (std::is_same<Tds, half2>::value) {
+            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);
+        } else {
+            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
+        }
+    }
+}
+
+typedef half  (*dequantize_1_f16_t)(const void *, const int64_t);
+typedef float (*dequantize_1_f32_t)(const void *, const int64_t);
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int64_t ib    =  i          /  QK4_0;
+    const int     iqs   =  i          % (QK4_0/2);
+    const int     shift = (i % QK4_0) / (QK4_0/2);
+
+    const T   d  = x[ib].d;
+    const int q0 = x[ib].qs[iqs];
+    const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int64_t ib    =  i          /  QK4_1;
+    const int     iqs   =  i          % (QK4_1/2);
+    const int     shift = (i % QK4_1) / (QK4_1/2);
+
+    const half2 dm = x[ib].dm;
+    const int   q0 = x[ib].qs[iqs];
+    const int   q  = ((q0 >> (4*shift)) & 0x0F);
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return __low2half(dm)*((half) q) + __high2half(dm);
+    }
+#endif // FP16_AVAILABLE
+
+    return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int64_t ib    =  i          /  QK5_0;
+    const int     idq   =  i          %  QK5_0;
+    const int     iqs   =  i          % (QK5_0/2);
+    const int     shift = (i % QK5_0) / (QK5_0/2);
+
+    const T   d   = x[ib].d;
+    const int ql0 = x[ib].qs[iqs];
+    const int qh0 = get_int_b2(x[ib].qh, 0);
+    const int ql  = ((ql0 >> (4*shift)) & 0x0F);
+    const int qh  = ((qh0 >> idq) << 4) & 0x10;
+    const int q   = (ql | qh) - 16;
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int64_t ib    =  i          /  QK5_1;
+    const int     idq   =  i          %  QK5_1;
+    const int     iqs   =  i          % (QK5_1/2);
+    const int     shift = (i % QK5_1) / (QK5_1/2);
+
+    const half2 dm  = x[ib].dm;
+    const int   ql0 = x[ib].qs[iqs];
+    const int   qh0 = get_int_b4(x[ib].qh, 0);
+    const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
+    const int   qh  = ((qh0 >> idq) << 4) & 0x10;
+    const int   q   = (ql | qh);
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return __low2half(dm)*((half) q) + __high2half(dm);
+    }
+#endif // FP16_AVAILABLE
+
+    return __low2float(dm)*((float) q) + __high2float(dm);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const int64_t ib  = i / QK8_0;
+    const int     iqs = i % QK8_0;
+
+    const T   d = x[ib].d;
+    const int q = x[ib].qs[iqs];
+
+#ifdef FP16_AVAILABLE
+    if (std::is_same<T, half>::value) {
+        return ((half) d)*((half) q);
+    }
+#endif // FP16_AVAILABLE
+
+    return ((float) d)*((float) q);
+}
+
+template <typename T>
+static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
+    const half * x = (const half *) vx;
+
+    return x[i];
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+        nullptr;
+}
+
+template <int D>
+constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<half> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<half> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<half> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<half> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<half> :
+        type_V == GGML_TYPE_F16 ? dequantize_1_f16<half> :
+        nullptr;
+}
+
+constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
+    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0<float> :
+        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1<float> :
+        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0<float> :
+        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1<float> :
+        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0<float> :
+        type_V == GGML_TYPE_F16 ? dequantize_1_f16<float> :
+        nullptr;
+}
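+// The constexpr selectors above resolve the K/V tensor type to the matching device function
+// at compile time; unsupported type combinations yield nullptr.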
+
+template<int D, int parallel_blocks> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_combine_results(
+        const float  * __restrict__ VKQ_parts,
+        const float2 * __restrict__ VKQ_meta,
+        float * __restrict__ dst) {
+    VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
+    VKQ_meta  += parallel_blocks   * gridDim.y*blockIdx.x;
+    dst       +=                 D * gridDim.y*blockIdx.x;
+
+    const int tid = threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float2 meta[parallel_blocks];
+    if (tid < 2*parallel_blocks) {
+        ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
+    }
+
+    __syncthreads();
+
+    float kqmax = meta[0].x;
+#pragma unroll
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = max(kqmax, meta[l].x);
+    }
+
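+    // Rescale each partial result by exp(local_max - global_max) so that all parts share the
+    // same softmax normalization; scales below the FTZ threshold are flushed to zero.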
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+#pragma unroll
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float diff = meta[l].x - kqmax;
+        const float KQ_max_scale = expf(diff);
+        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
+        *((uint32_t *) &KQ_max_scale) &= ftz_mask;
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y;
+    }
+
+    dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
+}
+
+static void on_no_fattn_vec_case(const int D) {
+    if (D == 64) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
+        fprintf(stderr, "By default only f16 KV cache is supported.\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
+        GGML_ABORT("fatal error");
+    } else if (D == 128) {
+        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
+        fprintf(stderr, "Supported combinations:\n");
+        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
+        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
+        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
+        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
+        GGML_ABORT("fatal error");
+    } else {
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Only f16 is supported.\n");
+        GGML_ABORT("fatal error");
+    }
+}
+
+template <int D, int parallel_blocks>
+void launch_fattn(
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
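+    // If the kernel needs f16 K/V but the tensors are quantized, dequantize them into f16
+    // pool buffers and rescale the byte strides to the f16 layout.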
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        K_data,
+        V_data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
diff --git a/llama/ggml-cuda/fattn-tile-f16.cu b/llama/ggml-cuda/fattn-tile-f16.cu
new file mode 100644
index 000000000..72d265ef2
--- /dev/null
+++ b/llama/ggml-cuda/fattn-tile-f16.cu
@@ -0,0 +1,379 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+
+#define FATTN_KQ_STRIDE_TILE_F16 64
+
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_tile_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic0;
+
+    const int stride_KV2 = nb11 / sizeof(half2);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+    __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
+    half2 * KQ2 = (half2 *) KQ;
+
+    __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
+
+    half kqmax[ncols/nwarps];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        kqmax[j0/nwarps] = -HALF_MAX_HALF;
+    }
+    half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
+
+    half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+    // Convert Q to half2 and store in registers:
+    __shared__ half2 Q_h2[ncols][D/2];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
+            Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+        }
+    }
+
+    __syncthreads();
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        half kqmax_new[ncols/nwarps];
+#pragma unroll
+        for (int j = 0; j < ncols/nwarps; ++j) {
+            kqmax_new[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+                const int k_KQ = k_KQ_0 + threadIdx.x;
+
+                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+            }
+        }
+
+        __syncthreads();
+
+        half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
+
+#pragma unroll
+        for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
+            half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
+            half2 Q_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+                const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+            }
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
+            }
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+#pragma unroll
+                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                    sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
+            const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                half sum;
+                if (use_logit_softcap) {
+                    const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                    sum = logit_softcap * tanhf(tmp.x + tmp.y);
+                } else {
+                    sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
+                sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
+
+                KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
+            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+#pragma unroll
+            for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
+                const half2 val = h2exp(diff);
+                kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
+                KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
+            const int k = k0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
+            half2  V_k[(D/2)/WARP_SIZE][2];
+            half2 KQ_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
+                V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                const int j = j0 + threadIdx.y;
+
+                KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+#pragma unroll
+                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                    VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
+                    VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
+                }
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
+        half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
+        kqsum_j = warp_reduce_sum((float)kqsum_j);
+
+#pragma unroll
+        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+            const int i0 = i00 + 2*threadIdx.x;
+
+            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            if (parallel_blocks == 1) {
+                dst_val /= __half2half2(kqsum_j);
+            }
+            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] =  __low2float(dst_val);
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+        }
+
+        if (parallel_blocks != 1 && threadIdx.x == 0) {
+            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        }
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        default: {
+            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
+
+void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const int32_t precision = KQV->op_params[3];
+    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] <= 16) {
+        constexpr int cols_per_block = 16;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 32;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
+}
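
The tile kernel above keeps a running maximum (kqmax), denominator (kqsum) and accumulator (VKQ) per query column and rescales them by exp(old_max - new_max) whenever a new KV tile raises the maximum. A standalone C++ sketch of that streaming softmax for a single query row with float accumulation (State and accumulate_tile are illustrative names, not from this diff):

#include <cmath>
#include <vector>

// Streaming softmax over one row of logits processed tile by tile, mirroring
// how the tile kernels maintain kqmax, kqsum and VKQ across KV tiles.
struct State {
    float max = -INFINITY; // running maximum of all logits seen so far
    float sum = 0.0f;      // running sum of exp(logit - max)
};

void accumulate_tile(State & s, const std::vector<float> & tile_logits) {
    float new_max = s.max;
    for (float x : tile_logits) {
        new_max = std::fmax(new_max, x);
    }
    const float rescale = std::exp(s.max - new_max); // <= 1, shrinks the older terms
    s.sum *= rescale;
    for (float x : tile_logits) {
        s.sum += std::exp(x - new_max);
    }
    s.max = new_max;
    // The kernels rescale the accumulated V*KQ products by the same factor
    // (VKQ *= KQ_max_scale) before adding the contribution of the new tile.
}
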
diff --git a/llama/ggml-cuda/fattn-tile-f16.cuh b/llama/ggml-cuda/fattn-tile-f16.cuh
new file mode 100644
index 000000000..4a3965ed3
--- /dev/null
+++ b/llama/ggml-cuda/fattn-tile-f16.cuh
@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/fattn-tile-f32.cu b/llama/ggml-cuda/fattn-tile-f32.cu
new file mode 100644
index 000000000..3be1c7a62
--- /dev/null
+++ b/llama/ggml-cuda/fattn-tile-f32.cu
@@ -0,0 +1,375 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f32.cuh"
+
+#define FATTN_KQ_STRIDE_TILE_F32 32
+
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_tile_ext_f32(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half   * maskh = (const half   *)  mask + ne11*ic0;
+
+    const int stride_KV2 = nb11 / sizeof(half2);
+
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+
+    __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
+
+    __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
+    float2 * KV_tmp2 = (float2 *) KV_tmp;
+
+    float kqmax[ncols/nwarps];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+    }
+    float kqsum[ncols/nwarps] = {0.0f};
+
+    float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
+
+    // Convert Q to half2 and store in registers:
+    __shared__ float Q_f[ncols][D];
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
+            float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
+            Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
+            Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
+        }
+    }
+
+    __syncthreads();
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        float kqmax_new[ncols/nwarps];
+#pragma unroll
+        for (int j = 0; j < ncols/nwarps; ++j) {
+            kqmax_new[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
+                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] =  __low2float(tmp);
+                KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
+            }
+        }
+
+        __syncthreads();
+
+        float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
+
+#pragma unroll
+        for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
+            float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
+            float Q_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+                const int i_KQ = i_KQ_0 + threadIdx.x;
+
+                K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
+            }
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
+            }
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+#pragma unroll
+                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                    sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
+            const int i_KQ = i_KQ_0 + threadIdx.x;
+
+#pragma unroll
+            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                const int j_KQ = j_KQ_0 + threadIdx.y;
+
+                if (use_logit_softcap) {
+                    sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
+
+                sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+
+                KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
+            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
+
+            float kqsum_add = 0.0f;
+#pragma unroll
+            for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
+                const float val = expf(diff);
+                kqsum_add += val;
+                KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
+            }
+            kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
+            const int k = k0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
+            float2 V_k[(D/2)/WARP_SIZE];
+            float  KQ_k[ncols/nwarps];
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                const int j = j0 + threadIdx.y;
+
+                KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+#pragma unroll
+                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+                    VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
+                    VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
+                }
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y;
+
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+
+        float kqsum_j = kqsum[j_VKQ_0/nwarps];
+        kqsum_j = warp_reduce_sum(kqsum_j);
+
+#pragma unroll
+        for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
+            const int i0 = i00 + 2*threadIdx.x;
+
+            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+            if (parallel_blocks == 1) {
+                dst_val.x /= kqsum_j;
+                dst_val.y /= kqsum_j;
+            }
+            const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
+            dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+        }
+
+        if (parallel_blocks != 1 && threadIdx.x == 0) {
+            dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+        }
+    }
+}
+
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        } break;
+        default: {
+            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
+
+void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q = dst->src[0];
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] <= 16) {
+        constexpr int cols_per_block = 16;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 32;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
+}
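
Both tile kernels read a per-head ALiBi slope via get_alibi_slope, built from the m0, m1 and n_head_log2 values that launch_fattn derives from max_bias. A host-side C++ sketch of that computation, assuming the usual ALiBi scheme (the actual device helper lives in common.cuh and may differ in detail; alibi_slope is an illustrative name):

#include <cmath>
#include <cstdint>

// Per-head ALiBi slope from the quantities launch_fattn sets up
// (m0, m1, n_head_log2). Host-side sketch only.
float alibi_slope(float max_bias, uint32_t head, uint32_t n_head) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: mask values are used unscaled
    }
    const uint32_t n_head_log2 = 1u << (uint32_t) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // Heads below the largest power of two use base m0 with exponent head+1,
    // the remaining heads use base m1 with odd exponents.
    const float base = head < n_head_log2 ? m0 : m1;
    const int   exp  = head < n_head_log2 ? head + 1 : 2*(head - n_head_log2) + 1;
    return std::pow(base, (float) exp);
}
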
diff --git a/llama/ggml-cuda/fattn-tile-f32.cuh b/llama/ggml-cuda/fattn-tile-f32.cuh
new file mode 100644
index 000000000..8a5eef471
--- /dev/null
+++ b/llama/ggml-cuda/fattn-tile-f32.cuh
@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/fattn-vec-f16.cuh b/llama/ggml-cuda/fattn-vec-f16.cuh
new file mode 100644
index 000000000..334a05c37
--- /dev/null
+++ b/llama/ggml-cuda/fattn-vec-f16.cuh
@@ -0,0 +1,467 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb02* blockIdx.y              + nb01*ic0;
+    K += nb12*(blockIdx.y / gqa_ratio);
+    V += nb22*(blockIdx.y / gqa_ratio);
+
+    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ half KQ[ncols*D];
+    half2 * KQ2 = (half2 *) KQ;
+
+    half kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -HALF_MAX_HALF;
+    }
+    half kqsum[ncols] = {0.0f};
+
+    __shared__ half kqmax_shared[ncols][WARP_SIZE];
+    __shared__ half kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
+    half2  Q_h2[ncols][D/(2*WARP_SIZE)];
+    int   Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
+    half2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+    if (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int   * tmp_q_i32 = (int   *) &KQ[j*D];
+            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    tmp_q_i32[i] = 0;
+                }
+                if (threadIdx.x < D/QK8_1) {
+                    tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
+                }
+                continue;
+            }
+
+            const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int   * tmp_q_i32 = (int   *) &KQ[j*D];
+            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        __syncthreads();
+    } else {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
+            }
+        }
+    }
+
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -HALF_MAX_HALF;
+    }
+
+    half2 VKQ[ncols] = {{0.0f, 0.0f}};
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum((float)sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap*tanhf(sum);
+                }
+
+                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }
+
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum((float)kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
+
+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int nwarps = D/WARP_SIZE;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
+
+    const int32_t precision = KQV->op_params[3];
+    GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block  = 8;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_f16_case                     \
+    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
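
When a logit soft cap is configured, launch_fattn divides the pre-softmax scale by the cap and the kernels above apply logit_softcap * tanhf(.) to every KQ dot product, bounding logits to (-cap, cap). A small C++ sketch of the combined effect (softcapped_logit is an illustrative name, not from this diff):

#include <cmath>

// Combined effect of the pre-softmax scale and the logit soft cap: the host
// divides scale by the cap (see launch_fattn), and the kernel applies
// logit_softcap * tanh(.) to each KQ dot product.
float softcapped_logit(float qk_dot, float scale, float logit_softcap) {
    if (logit_softcap == 0.0f) {
        return scale * qk_dot; // cap disabled: plain scaled dot product
    }
    const float effective_scale = scale / logit_softcap;
    return logit_softcap * std::tanh(effective_scale * qk_dot);
}

For logit_softcap == 0 the kernels take the plain scaled dot product, which is why the launchers compile separate use_logit_softcap variants of each kernel.
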
diff --git a/llama/ggml-cuda/fattn-vec-f32.cuh b/llama/ggml-cuda/fattn-vec-f32.cuh
new file mode 100644
index 000000000..0bb230004
--- /dev/null
+++ b/llama/ggml-cuda/fattn-vec-f32.cuh
@@ -0,0 +1,445 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_vec_ext_f32(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+    constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
+
+    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
+    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb02* blockIdx.y              + nb01*ic0;
+    K += nb12*(blockIdx.y / gqa_ratio);
+    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    const half * maskh = (const half   *)  mask + ne11*ic0;
+
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+
+    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+    constexpr int nwarps = D / WARP_SIZE;
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
+
+    __shared__ float KQ[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ[j*D + tid] = -FLT_MAX/2.0f;
+    }
+
+    float kqmax[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqmax[j] = -FLT_MAX/2.0f;
+    }
+    float kqsum[ncols] = {0.0f};
+
+    __shared__ float kqmax_shared[ncols][WARP_SIZE];
+    __shared__ float kqsum_shared[ncols][WARP_SIZE];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (threadIdx.y == 0) {
+            kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
+            kqsum_shared[j][threadIdx.x] = 0.0f;
+        }
+    }
+    __syncthreads();
+
+    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+    float2  Q_f2[ncols][D/(2*WARP_SIZE)];
+    int    Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
+    float2  Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
+    if (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 2 && ic0 + j >= ne01) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    tmp_q_i32[i] = 0;
+                }
+                if (threadIdx.x < D/QK8_1) {
+                    tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
+                }
+                continue;
+            }
+
+            const float * Q_f = (const float *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
+                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        __syncthreads();
+    } else {
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                Q_f2[j][i0/WARP_SIZE]    = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
+                Q_f2[j][i0/WARP_SIZE].x *= scale;
+                Q_f2[j][i0/WARP_SIZE].y *= scale;
+            }
+        }
+    }
+
+    float VKQ[ncols] = {0.0f};
+
+    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        float kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap*tanhf(sum);
+                }
+
+                sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+
+                kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_new_arr[j];
+
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k = 0; k < D; ++k) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) {
+                break;
+            }
+
+            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_ki*KQ[j*D + k];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        float dst_val = VKQ[j_VKQ];
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+}
+
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int nwarps = D/WARP_SIZE;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
+
+    GGML_ASSERT(K->type == type_K);
+    GGML_ASSERT(V->type == type_V);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block  = 1;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 2) {
+        constexpr int cols_per_block  = 2;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 4) {
+        constexpr int cols_per_block  = 4;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8) {
+        constexpr int cols_per_block  = 8;
+        constexpr int parallel_blocks = 4;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block  = 8;
+    constexpr int parallel_blocks = 1;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V)                          \
+    template void ggml_cuda_flash_attn_ext_vec_f32_case                     \
+    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0);
+
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16);
+
+extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/fattn-wmma-f16.cuh b/llama/ggml-cuda/fattn-wmma-f16.cuh
new file mode 100644
index 000000000..d82984f40
--- /dev/null
+++ b/llama/ggml-cuda/fattn-wmma-f16.cuh
@@ -0,0 +1,569 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+#ifdef FP16_MMA_AVAILABLE
+#include <mma.h>
+#endif // FP16_MMA_AVAILABLE
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#ifdef FP16_MMA_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
+
+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                    if (use_logit_softcap) {
+                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    }
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                    if (use_logit_softcap) {
+                        // There is no dedicated hyperbolic tangent function for half2.
+                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    }
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // FP16_MMA_AVAILABLE
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    constexpr int nwarps = 4;
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+        return;
+    }
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+}
+
+#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
+    template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \
+    <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
+// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
+extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+
+extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
diff --git a/llama/ggml-cuda/fattn.cu b/llama/ggml-cuda/fattn.cu
new file mode 100644
index 000000000..4828e9d84
--- /dev/null
+++ b/llama/ggml-cuda/fattn.cu
@@ -0,0 +1,371 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-tile-f16.cuh"
+#include "fattn-tile-f32.cuh"
+#include "fattn-vec-f16.cuh"
+#include "fattn-vec-f32.cuh"
+#include "fattn-wmma-f16.cuh"
+#include "fattn.cuh"
+
+#include <cstring>
+
+static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+
+    if (prec != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 80:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 112:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    switch (Q->ne[0]) {
+        case 64:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            break;
+        case 80:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            break;
+        case 96:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            break;
+        case 112:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            break;
+        case 128:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            break;
+        case 256:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
+
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
+    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
+        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
+        return;                                                             \
+    }                                                                       \
+
+static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)
+
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#else
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+    on_no_fattn_vec_case(Q->ne[0]);
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    ggml_cuda_set_device(ctx.device);
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+
+    // On AMD the tile kernels perform poorly, use the vec kernel instead:
+    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
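+    // Without fast FP16 arithmetic, fall back to the FP32 kernels: the vec kernel for small batches or D == 256, the tile kernel otherwise.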
+    if (!fast_fp16_available(cc)) {
+        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+        }
+        return;
+    }
+
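+    // FP16 arithmetic is fast but FP16 tensor cores are unavailable: use the FP16 vec kernel for small batches, the FP16 tile kernel otherwise.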
+    if (!fp16_mma_available(cc)) {
+        if (Q->ne[1] <= 8) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        }
+        return;
+    }
+
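+    // Batch size 1 with a head size that is a multiple of 2*WARP_SIZE: prefer the vec kernels over the WMMA kernel.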
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+        if (prec == GGML_PREC_DEFAULT) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            return;
+        } else if(Q->ne[0] <= 128) {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            return;
+        }
+    }
+
+    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+}
diff --git a/llama/ggml-cuda/fattn.cuh b/llama/ggml-cuda/fattn.cuh
new file mode 100644
index 000000000..6947118e1
--- /dev/null
+++ b/llama/ggml-cuda/fattn.cuh
@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/getrows.cu b/llama/ggml-cuda/getrows.cu
new file mode 100644
index 000000000..6cf1e516e
--- /dev/null
+++ b/llama/ggml-cuda/getrows.cu
@@ -0,0 +1,203 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "getrows.cuh"
+#include "dequantize.cuh"
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
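+    // Row of src0 to gather, looked up in the int32 index tensor src1: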
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
+}
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    GGML_ASSERT(ne00 % 2 == 0);
+
+    k_get_rows<qk, qr, dequantize_kernel><<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /* s0,*/ s1, s2, s3,
+            /* nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    GGML_UNUSED(dst);
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            break;
+    }
+}
diff --git a/llama/ggml-cuda/getrows.cuh b/llama/ggml-cuda/getrows.cuh
new file mode 100644
index 000000000..bbbf482d3
--- /dev/null
+++ b/llama/ggml-cuda/getrows.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_GET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/ggml-cuda.cu b/llama/ggml-cuda/ggml-cuda.cu
new file mode 100644
index 000000000..0894fdad7
--- /dev/null
+++ b/llama/ggml-cuda/ggml-cuda.cu
@@ -0,0 +1,3318 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-cuda.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-cuda/common.cuh"
+#include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argmax.cuh"
+#include "ggml-cuda/argsort.cuh"
+#include "ggml-cuda/binbcast.cuh"
+#include "ggml-cuda/clamp.cuh"
+#include "ggml-cuda/concat.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/count-equal.cuh"
+#include "ggml-cuda/cpy.cuh"
+#include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/diagmask.cuh"
+#include "ggml-cuda/fattn.cuh"
+#include "ggml-cuda/getrows.cuh"
+#include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmv.cuh"
+#include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/out-prod.cuh"
+#include "ggml-cuda/pad.cuh"
+#include "ggml-cuda/pool2d.cuh"
+#include "ggml-cuda/quantize.cuh"
+#include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/sum.cuh"
+#include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/unary.cuh"
+#include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/wkv6.cuh"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <float.h>
+#include <limits>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdarg.h>
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
+    int id = -1; // in case cudaGetDevice fails
+    cudaGetDevice(&id);
+
+    GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
+    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
+    GGML_LOG_ERROR("  %s\n", stmt);
+    // abort with GGML_ABORT to get a stack trace
+    GGML_ABORT(GGML_CUDA_NAME " error");
+}
+
+// this is faster on Windows
+// probably because the Windows CUDA libraries forget to make this check before invoking the drivers
+void ggml_cuda_set_device(int device) {
+    int current_device;
+    CUDA_CHECK(cudaGetDevice(&current_device));
+
+    if (device == current_device) {
+        return;
+    }
+
+    CUDA_CHECK(cudaSetDevice(device));
+}
+
+int ggml_cuda_get_device() {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    return id;
+}
+
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+
+#if !defined(GGML_USE_HIP)
+    cudaError_t err;
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
+    {
+        err = cudaMallocManaged(ptr, size);
+    }
+    else
+    {
+        err = cudaMalloc(ptr, size);
+    }
+    return err;
+#else
+    return cudaMalloc(ptr, size);
+#endif // !defined(GGML_USE_HIP)
+
+#endif
+}
+
+static ggml_cuda_device_info ggml_cuda_init() {
+#ifdef __HIP_PLATFORM_AMD__
+    // Workaround for a rocBLAS bug when using multiple graphics cards:
+    // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
+    rocblas_initialize();
+    CUDA_CHECK(cudaDeviceSynchronize());
+#endif
+
+    ggml_cuda_device_info info = {};
+
+    cudaError_t err = cudaGetDeviceCount(&info.device_count);
+    if (err != cudaSuccess) {
+        GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+        return info;
+    }
+
+    GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
+
+    int64_t total_vram = 0;
+#ifdef GGML_CUDA_FORCE_MMQ
+    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    yes\n", __func__);
+#else
+    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ:    no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
+#else
+    GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
+    GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+    for (int id = 0; id < info.device_count; ++id) {
+        int device_vmm = 0;
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+        CUdevice device;
+        CU_CHECK(cuDeviceGet(&device, id));
+        CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+
+        if (device_vmm) {
+            CUmemAllocationProp alloc_prop = {};
+            alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+            alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            alloc_prop.location.id = id;
+            CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
+        }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+        info.devices[id].vmm = !!device_vmm;
+
+        cudaDeviceProp prop;
+        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+
+        info.default_tensor_split[id] = total_vram;
+        total_vram += prop.totalGlobalMem;
+
+        info.devices[id].nsm   = prop.multiProcessorCount;
+        info.devices[id].smpb  = prop.sharedMemPerBlock;
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+        info.devices[id].smpbo = prop.sharedMemPerBlock;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
+#else
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+    }
+
+    for (int id = 0; id < info.device_count; ++id) {
+        info.default_tensor_split[id] /= total_vram;
+    }
+
+    // configure logging to stdout
+    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+    return info;
+}
+
+const ggml_cuda_device_info & ggml_cuda_info() {
+    static ggml_cuda_device_info info = ggml_cuda_init();
+    return info;
+}
+
+// #define DEBUG_CUDA_MALLOC
+
+// buffer pool for cuda (legacy)
+struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+    static const int MAX_BUFFERS = 256;
+
+    int device;
+    struct ggml_cuda_buffer {
+        void * ptr = nullptr;
+        size_t size = 0;
+    };
+
+    ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
+    size_t pool_size = 0;
+
+    explicit ggml_cuda_pool_leg(int device) :
+        device(device) {
+    }
+
+    ~ggml_cuda_pool_leg() {
+        ggml_cuda_set_device(device);
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cuda_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                CUDA_CHECK(cudaFree(b.ptr));
+                pool_size -= b.size;
+            }
+        }
+        GGML_ASSERT(pool_size == 0);
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+#ifdef DEBUG_CUDA_MALLOC
+        int nnz = 0;
+        size_t max_size = 0;
+#endif
+        size_t best_diff = 1ull << 36;
+        int ibest = -1;
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cuda_buffer& b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+                ++nnz;
+                if (b.size > max_size) max_size = b.size;
+#endif
+                if (b.size >= size) {
+                    size_t diff = b.size - size;
+                    if (diff < best_diff) {
+                        best_diff = diff;
+                        ibest = i;
+                        if (!best_diff) {
+                            void * ptr = b.ptr;
+                            *actual_size = b.size;
+                            b.ptr = nullptr;
+                            b.size = 0;
+                            return ptr;
+                        }
+                    }
+                }
+            }
+        }
+        if (ibest >= 0) {
+            ggml_cuda_buffer& b = buffer_pool[ibest];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
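+        // no pooled buffer is large enough: allocate a fresh one with ~5% headroom,
+        // rounded up to 256 bytes, so slightly larger follow-up requests can reuse it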
+        void * ptr;
+        size_t look_ahead_size = (size_t) (1.05 * size);
+        look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+        ggml_cuda_set_device(device);
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
+        *actual_size = look_ahead_size;
+        pool_size += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+        GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
+#endif
+        return ptr;
+    }
+
+    void free(void * ptr, size_t size) override {
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cuda_buffer& b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+        GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
+        ggml_cuda_set_device(device);
+        CUDA_CHECK(cudaFree(ptr));
+        pool_size -= size;
+    }
+};
+
+// pool with virtual memory
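+// reserves a large virtual address range once and maps physical memory into it on
+// demand, so the pool can grow in place without copying existing allocations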
+#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
+    static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
+
+    int device;
+    CUdeviceptr pool_addr = 0;
+    size_t pool_used = 0;
+    size_t pool_size = 0;
+    size_t granularity;
+
+    explicit ggml_cuda_pool_vmm(int device) :
+        device(device),
+        granularity(ggml_cuda_info().devices[device].vmm_granularity) {
+    }
+
+    ~ggml_cuda_pool_vmm() {
+        if (pool_addr != 0) {
+            CU_CHECK(cuMemUnmap(pool_addr, pool_size));
+            CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
+        }
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
+        const size_t alignment = 128;
+        size = alignment * ((size + alignment - 1) / alignment);
+
+        size_t avail = pool_size - pool_used;
+
+        if (size > avail) {
+            // round up to the next multiple of the granularity
+            size_t reserve_size = size - avail;
+            reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+            GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+            // allocate more physical memory
+            CUmemAllocationProp prop = {};
+            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            prop.location.id = device;
+            CUmemGenericAllocationHandle handle;
+            CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+            // reserve virtual address space (if not already reserved)
+            if (pool_addr == 0) {
+                CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+            }
+
+            // map at the end of the pool
+            CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
+
+            // the memory allocation handle is no longer needed after mapping
+            CU_CHECK(cuMemRelease(handle));
+
+            // set access
+            CUmemAccessDesc access = {};
+            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            access.location.id = device;
+            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+            CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
+
+            // add to the pool
+            pool_size += reserve_size;
+
+            //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+            //       device, (unsigned long long) (pool_size/1024/1024),
+            //       (unsigned long long) (reserve_size/1024/1024));
+        }
+
+        GGML_ASSERT(pool_addr != 0);
+
+        void * ptr = (void *) (pool_addr + pool_used);
+        *actual_size = size;
+        pool_used += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+        printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+        return ptr;
+    }
+
+    void free(void * ptr, size_t size) override {
+#ifdef DEBUG_CUDA_MALLOC
+        printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+        pool_used -= size;
+
+        // all deallocations must be in reverse order of the allocations
+        GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
+    }
+};
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+    if (ggml_cuda_info().devices[device].vmm) {
+        return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+    return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
+}
+
+// cuda buffer
+
+struct ggml_backend_cuda_buffer_context {
+    int device;
+    void * dev_ptr = nullptr;
+    std::string name;
+
+    ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
+        device(device), dev_ptr(dev_ptr),
+        name(GGML_CUDA_NAME + std::to_string(device)) {
+    }
+
+    ~ggml_backend_cuda_buffer_context() {
+        CUDA_CHECK(cudaFree(dev_ptr));
+    }
+};
+
+static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+    delete ctx;
+
+    // TODO: this needs to be freed in cuda and hipblas backends because
+    // the cuda backend implementation compiled with msvc
+    free(buffer);
+}
+
+static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer;
+}
+
+static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+    return ctx->dev_ptr;
+}
+
+static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    if (tensor->view_src != NULL) {
+        assert(tensor->view_src->buffer->buft == buffer->buft);
+        return;
+    }
+
+    if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+        // initialize padding to 0 to avoid possible NaN values
+        size_t original_size = ggml_nbytes(tensor);
+        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+        if (padded_size > original_size) {
+            ggml_cuda_set_device(ctx->device);
+            CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
+        }
+    }
+}
+
+static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_cuda(src->buffer)) {
+        ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
+        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
+        if (src_ctx->device == dst_ctx->device) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
+        } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+            return false;
+#else
+            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
+#endif
+        }
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaDeviceSynchronize());
+    CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size));
+    CUDA_CHECK(cudaDeviceSynchronize());
+}
+
+static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_cuda_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cuda_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cuda_buffer_init_tensor,
+    /* .memset_tensor   = */ ggml_backend_cuda_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_cuda_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cuda_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cuda_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cuda_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// cuda buffer type
+struct ggml_backend_cuda_buffer_type_context {
+    int device;
+    std::string name;
+};
+
+static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+    return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+    ggml_cuda_set_device(buft_ctx->device);
+
+    void * dev_ptr;
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+        GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
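+// quantized tensors are padded to a multiple of MATRIX_ROW_PADDING elements so that
+// kernels reading whole blocks past the last row stay within the allocation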
+static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    size_t size = ggml_nbytes(tensor);
+    int64_t ne0 = tensor->ne[0];
+
+    if (ggml_is_quantized(tensor->type)) {
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+    }
+
+    return size;
+
+    GGML_UNUSED(buft);
+}
+
+static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_cuda_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
+    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (device >= ggml_backend_cuda_get_device_count()) {
+        return nullptr;
+    }
+
+    static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
+
+    static bool ggml_backend_cuda_buffer_type_initialized = false;
+
+    if (!ggml_backend_cuda_buffer_type_initialized) {
+        for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) {
+            ggml_backend_cuda_buffer_types[i] = {
+                /* .iface    = */ ggml_backend_cuda_buffer_type_interface,
+                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i),
+                /* .context  = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
+            };
+        }
+        ggml_backend_cuda_buffer_type_initialized = true;
+    }
+
+    return &ggml_backend_cuda_buffer_types[device];
+}
+
+// cuda split buffer
+
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
+    int64_t row_rounding = 0;
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+            continue;
+        }
+
+        const int cc = ggml_cuda_info().devices[id].cc;
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
+    }
+    return row_rounding;
+}
+
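+// computes the [row_low, row_high) slice of a split tensor assigned to device id,
+// derived from the cumulative tensor_split fractions and rounded to the mul_mat_q tile size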
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
+    const int64_t nrows = ggml_nrows(tensor);
+    const int64_t rounding = get_row_rounding(tensor_split);
+
+    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
+    *row_low -= *row_low % rounding;
+
+    if (id == ggml_backend_cuda_get_device_count() - 1) {
+        *row_high = nrows;
+    } else {
+        *row_high = nrows*tensor_split[id + 1];
+        *row_high -= *row_high % rounding;
+    }
+}
+
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
+}
+
+struct ggml_backend_cuda_split_buffer_type_context {
+    int main_device;
+    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+    std::string name;
+};
+
+struct ggml_backend_cuda_split_buffer_context {
+    ~ggml_backend_cuda_split_buffer_context() {
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
+                for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+                    if (extra->events[id][is] != nullptr) {
+                        CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+                    }
+                }
+                if (extra->data_device[id] != nullptr) {
+                    CUDA_CHECK(cudaFree(extra->data_device[id]));
+                }
+            }
+            delete extra;
+        }
+    }
+
+    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
+};
+
+
+static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+    delete ctx;
+}
+
+static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
+    return (void *)0x1000;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+
+    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+    ctx->tensor_extras.push_back(extra);
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+
+        // FIXME: do not crash if cudaMalloc fails
+        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
+        ggml_cuda_set_device(id);
+        char * buf;
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
+
+        // set padding to 0 to avoid possible NaN values
+        if (size > original_size) {
+            CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
+        }
+
+        extra->data_device[id] = buf;
+
+        for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+        }
+    }
+    tensor->extra = extra;
+}
+
+static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    // split tensors must always be set in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        const size_t offset_split = row_low*nb1;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+
+        const char * buf_host = (const char *)data + offset_split;
+        CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+    }
+}
+
+static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    // split tensors must always be set in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+    const int64_t ne0 = tensor->ne[0];
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        const size_t offset_split = row_low*nb1;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+
+        char * buf_host = (char *)data + offset_split;
+        CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+    }
+}
+
+static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(value);
+}
+
+static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_cuda_split_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// cuda split buffer type
+
+static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+
+    return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
+    // instead, we allocate them for each tensor separately in init_tensor
+    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+    ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+
+    size_t total_size = 0;
+
+    const int64_t ne0 = tensor->ne[0];
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
+
+        total_size += ggml_nbytes_split(tensor, nrows_split);
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+    }
+
+    return total_size;
+}
+
+static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
+static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_cuda_split_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cuda_split_buffer_type_get_alignment,
+    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;
+
+    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
+
+    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
+    if (all_zero) {
+        tensor_split_arr = ggml_cuda_info().default_tensor_split;
+    } else {
+        float split_sum = 0.0f;
+        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+            tensor_split_arr[i] = split_sum;
+            split_sum += tensor_split[i];
+        }
+        for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+            tensor_split_arr[i] /= split_sum;
+        }
+    }
+
+    auto it = buft_map.find({main_device, tensor_split_arr});
+    if (it != buft_map.end()) {
+        return &it->second;
+    }
+    auto * ctx = new ggml_backend_cuda_split_buffer_type_context{
+        main_device,
+        tensor_split_arr,
+        GGML_CUDA_NAME + std::to_string(main_device) + "_Split",
+    };
+
+    struct ggml_backend_buffer_type buft {
+        /* .iface   = */ ggml_backend_cuda_split_buffer_type_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),
+        /* .context = */ ctx,
+    };
+
+    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
+    return &result.first->second;
+}
+
+// host buffer type
+
+static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+    return GGML_CUDA_NAME "_Host";
+
+    GGML_UNUSED(buft);
+}
+
+static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    CUDA_CHECK(cudaFreeHost(buffer->context));
+}
+
+static void * ggml_cuda_host_malloc(size_t size) {
+    if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+        return nullptr;
+    }
+
+    void * ptr = nullptr;
+    cudaError_t err = cudaMallocHost((void **) &ptr, size);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+        GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
+        return nullptr;
+    }
+
+    return ptr;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr = ggml_cuda_host_malloc(size);
+
+    if (ptr == nullptr) {
+        // fallback to cpu buffer
+        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cuda_host_buffer_type_name,
+            /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+        },
+        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),
+        /* .context  = */ nullptr,
+    };
+
+    return &ggml_backend_cuda_buffer_type_host;
+}
+
+//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
+//    return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+//}
+
+/// kernels
+
+typedef void (*ggml_cuda_op_mul_mat_t)(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
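+// copies rows [i1_low, i1_high) of the (i3, i2) plane of src into a contiguous destination:
+// a single memcpy when rows are contiguous, a 2D strided copy when only elements are,
+// and a per-row fallback otherwise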
+static cudaError_t ggml_cuda_cpy_tensor_2d(
+    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
+    const char * src_ptr = (const char *) src->data;
+    char       * dst_ptr = (char       *) dst;
+
+    const int64_t ne0 = src->ne[0];
+    const int64_t nb0 = src->nb[0];
+    const int64_t nb1 = src->nb[1];
+    const int64_t nb2 = src->nb[2];
+    const int64_t nb3 = src->nb[3];
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    const int64_t i1_diff = i1_high - i1_low;
+
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
+    } else if (nb0 == ts) {
+        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
+            if (r != cudaSuccess) {
+                return r;
+            }
+        }
+        return cudaSuccess;
+    }
+}
+
+static void ggml_cuda_op_mul_mat_cublas(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src0_dd_i  != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_dd_i   != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that cuBLAS writes into
+    int64_t ldc = id == ctx.device ? ne0 : row_diff;
+
+    const int compute_capability = ggml_cuda_info().devices[id].cc;
+
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+        // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16.alloc(ne);
+            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
+        }
+        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
+
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_F16) {
+            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+            GGML_ASSERT(to_fp16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_f16.alloc(ne);
+            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
+        }
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+        const half alpha_f16 = 1.0f;
+        const half beta_f16 = 0.0f;
+
+        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+            cu_compute_type = CUBLAS_COMPUTE_32F;
+        }
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
+                                src1_ptr,       CUDA_R_16F, ne10,
+                    &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
+                    cu_compute_type,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else {
+        ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
+        ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
+
+        if (src0->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            src0_ddq_as_f32.alloc(row_diff*ne00);
+            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
+        }
+        if (src1->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            src1_ddq_as_f32.alloc(src1_ncols*ne10);
+            to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+        }
+
+        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
+
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha, src0_ddf_i,  ne00,
+                            src1_ddf1_i, ne10,
+                    &beta,  dst_dd_i,    ldc));
+    }
+
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_padded_row_size);
+}
+
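+// enables peer-to-peer access between the main device and the other devices only for
+// small batches (n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE), and disables it otherwise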
+static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
+    static bool peer_access_enabled = false;
+
+    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+    if (peer_access_enabled == enable_peer_access) {
+        return;
+    }
+
+#ifdef NDEBUG
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        ggml_cuda_set_device(id);
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        ggml_cuda_set_device(id);
+
+        for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
+            if (id == id_other) {
+                continue;
+            }
+            if (id != main_device && id_other != main_device) {
+                continue;
+            }
+
+            int can_access_peer;
+            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
+            if (can_access_peer) {
+                if (enable_peer_access) {
+                    cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
+                    if (err != cudaErrorPeerAccessAlreadyEnabled) {
+                        CUDA_CHECK(err);
+                    } else {
+                        // reset the error
+                        cudaGetLastError();
+                    }
+                } else {
+                    cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
+                    if (err != cudaErrorPeerAccessNotEnabled) {
+                        CUDA_CHECK(err);
+                    } else {
+                        // reset the error
+                        cudaGetLastError();
+                    }
+                }
+            }
+        }
+    }
+
+    ggml_cuda_set_device(main_device);
+#endif // NDEBUG
+
+    peer_access_enabled = enable_peer_access;
+
+    GGML_UNUSED(main_device);
+}
+
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+    cudaMemcpy3DPeerParms p = {};
+    p.dstDevice = dstDevice;
+    p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+    p.srcDevice = srcDevice;
+    p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
+    p.extent = make_cudaExtent(width, height, 1);
+    return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+    GGML_UNUSED(dstDevice);
+    GGML_UNUSED(srcDevice);
+    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+static void ggml_cuda_op_mul_mat(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+    quantize_cuda_t quantize_src1) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int64_t nrows1 = ggml_nrows(src1);
+
+    GGML_ASSERT(ne03 == ne13);
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    const int64_t nb2 = dst->nb[2];
+    const int64_t nb3 = dst->nb[3];
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer));
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer));
+    ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+    ggml_backend_cuda_buffer_context * dst_ctx  = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+
+    const int64_t i02_divisor = ne12 / ne02;
+
+    const size_t src0_ts = ggml_type_size(src0->type);
+    const size_t src0_bs = ggml_blck_size(src0->type);
+    const size_t q8_1_ts = sizeof(block_q8_1);
+    const size_t q8_1_bs = QK8_1;
+
+    const bool src0_is_contiguous = ggml_is_contiguous(src0);
+    const bool src1_is_contiguous = ggml_is_contiguous(src1);
+
+    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
+    GGML_ASSERT(!(split && ne02 > 1));
+    GGML_ASSERT(!(split && ne03 > 1));
+    GGML_ASSERT(!(split && ne02 < ne12));
+
+    ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
+
+
+    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+    if (split) {
+        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+        tensor_split = buft_ctx->tensor_split;
+    }
+
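+    // per-device staging state: temporary pool allocations for src0/src1/dst and the
+    // row range [row_low, row_high) this device is responsible for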
+    struct dev_data {
+        int cc;
+
+        ggml_cuda_pool_alloc<char>   src0_dd_alloc;
+        ggml_cuda_pool_alloc<float> src1_ddf_alloc;
+        ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
+        ggml_cuda_pool_alloc<float>   dst_dd_alloc;
+
+        char  *  src0_dd = nullptr;
+        float * src1_ddf = nullptr; // float
+        char  * src1_ddq = nullptr; // q8_1
+        float *   dst_dd = nullptr;
+
+        int64_t  row_low;
+        int64_t row_high;
+    };
+
+    dev_data dev[GGML_CUDA_MAX_DEVICES];
+
+    int used_devices = 0;
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        dev[id].cc = ggml_cuda_info().devices[id].cc;
+
+        // by default, use all rows
+        dev[id].row_low  = 0;
+        dev[id].row_high = ne01;
+
+        // for multi GPU, get the row boundaries from tensor split
+        // and round to mul_mat_q tile sizes
+        if (split) {
+            const int64_t rounding = get_row_rounding(tensor_split);
+
+            if (id != 0) {
+                dev[id].row_low  = ne01*tensor_split[id];
+                if (dev[id].row_low < ne01) {
+                    dev[id].row_low -= dev[id].row_low % rounding;
+                }
+            }
+
+            if (id != ggml_backend_cuda_get_device_count() - 1) {
+                dev[id].row_high  = ne01*tensor_split[id + 1];
+                if (dev[id].row_high < ne01) {
+                    dev[id].row_high -= dev[id].row_high % rounding;
+                }
+            }
+        }
+    }
+
+    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+            continue;
+        }
+
+        used_devices++;
+
+        const bool src1_on_device = id == src1_ctx->device;
+        const bool  dst_on_device = id == dst_ctx->device;
+
+        ggml_cuda_set_device(id);
+        cudaStream_t stream = ctx.stream(id, 0);
+
+        if (src0_is_contiguous) {
+            dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
+        } else {
+            // If src0 is not contiguous it will be copied to a temporary buffer.
+            // This buffer needs to be cleared entirely because multiple regions will function as padding.
+            const size_t nbytes_data    = ggml_nbytes(src0);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
+        // TODO: remove this for MUSA once the Guilty Lockup issue is resolved
+#ifndef GGML_USE_MUSA
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
+#else // GGML_USE_MUSA
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
+#endif // !GGML_USE_MUSA
+        }
+
+        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
+        if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
+            const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
+        }
+
+        if (src1_on_device && src1_is_contiguous) {
+            dev[id].src1_ddf = (float *) src1->data;
+        } else {
+            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
+        }
+
+        if (quantize_src1) {
+            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+            }
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
+
+            if (src1_on_device && src1_is_contiguous) {
+                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+        }
+
+        if (dst_on_device) {
+            dev[id].dst_dd = (float *) dst->data;
+        } else {
+            const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
+            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
+        }
+    }
+
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signals that the main device has finished calculating the input data
+    if (split && used_devices > 1) {
+        ggml_cuda_set_device(ctx.device);
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
+    }
+
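+    // Stream src1 through in column chunks: a single device processes all ne11
+    // columns at once, while multiple devices work on MUL_MAT_SRC1_COL_STRIDE-column
+    // chunks that rotate across GGML_CUDA_MAX_STREAMS streams per device.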
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
+        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
+
+        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+            if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+                continue;
+            }
+
+            const bool src1_on_device = id == src1_ctx->device;
+            const bool  dst_on_device = id == dst_ctx->device;
+            const int64_t row_diff = dev[id].row_high - dev[id].row_low;
+
+            ggml_cuda_set_device(id);
+            cudaStream_t stream = ctx.stream(id, is);
+
+            // wait for main GPU data if necessary
+            if (split && (id != ctx.device || is != 0)) {
+                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
+            }
+
+            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+                const int64_t i03 = i0 / ne12;
+                const int64_t i02 = i0 % ne12;
+
+                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+                } else {
+                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+                }
+
+                // for split tensors the data begins at i0 == i0_offset_low
+                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+                char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
+                float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+
+                // the main device memory buffer can be on VRAM scratch, with space for all partial results
+                // in that case an offset on dst_ddf_i is needed
+                if (id == ctx.device) {
+                    dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
+                }
+
+                // copy src0, src1 to device if necessary
+                if (src1_is_contiguous) {
+                    if (id != ctx.device) {
+                        if (quantize_src1) {
+                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
+                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);
+                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
+                                const size_t height = src1_padded_col_size/(4*QK8_1);
+                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+                            } else {
+                                CUDA_CHECK(cudaMemcpyPeerAsync(
+                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                            }
+                        } else {
+                            float * src1_ddf_i_source = (float *) src1->data;
+                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
+                                                            src1_ncols*ne10*sizeof(float), stream));
+                        }
+                    }
+                } else if (src1_on_device && !src1_is_contiguous) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                                src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                } else {
+                    GGML_ABORT("fatal error");
+                }
+
+                if (quantize_src1 && !src1_is_contiguous) {
+                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+
+                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                }
+
+                // do the computation
+                op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
+                CUDA_CHECK(cudaGetLastError());
+
+                // copy dst to host or other device if necessary
+                if (!dst_on_device) {
+                    void * dst_off_device = dst->data;
+                    if (split) {
+                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
+                        // dst is NOT transposed.
+                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
+                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
+                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
+                    } else {
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0;
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
+                    }
+                }
+
+                // add event for the main device to wait on until other device is done
+                if (split && (id != ctx.device || is != 0)) {
+                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
+                }
+            }
+        }
+    }
+
+    // main device waits for all other devices to be finished
+    if (split && ggml_backend_cuda_get_device_count() > 1) {
+        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+        is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;
+
+        ggml_cuda_set_device(ctx.device);
+        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+            if (dev[id].row_low == dev[id].row_high) {
+                continue;
+            }
+            for (int64_t is = 0; is < is_max; ++is) {
+                CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
+            }
+        }
+    }
+}
+
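+// Fills per-batch-entry pointer arrays for src0, src1 and dst, applying the
+// broadcast factors r2/r3 when mapping (i12, i13) back onto src0's batch dims;
+// the arrays are consumed by cublasGemmBatchedEx below.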
+static __global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, char * dst,
+        const void ** ptrs_src, void ** ptrs_dst,
+        int64_t ne12, int64_t ne13,
+        int64_t ne23,
+        size_t  nb02, size_t  nb03,
+        size_t  nb12, size_t  nb13,
+        size_t  nbd2, size_t  nbd3,
+        int64_t r2,   int64_t r3) {
+    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int64_t i03 = i13 / r3;
+    int64_t i02 = i12 / r2;
+
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+}
+
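+// Batched F16 matrix multiplication via cuBLAS: src1 is converted to F16 if needed,
+// then cublasGemmStridedBatchedEx is used when the batch dims are contiguous and not
+// broadcast, and cublasGemmBatchedEx with explicit pointer arrays otherwise. With
+// default precision the result is accumulated into an F16 buffer and converted back
+// to F32; with higher requested precision dst is written in F32 directly.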
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+
+    GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t ne_dst = ggml_nelements(dst);
+
+    cudaStream_t main_stream = ctx.stream();
+
+    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
+
+    void * src0_ddq = src0->data;
+    half * src0_f16 = (half *) src0_ddq;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    // convert src1 to fp16
+    ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
+        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
+    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
+
+    ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
+    char * dst_t;
+
+    cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+    cudaDataType_t      cu_data_type    = CUDA_R_16F;
+
+    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+    }
+
+    // dst strides
+    size_t nbd2 = dst->nb[2];
+    size_t nbd3 = dst->nb[3];
+
+    const half  alpha_f16 = 1.0f;
+    const half  beta_f16  = 0.0f;
+
+    const float alpha_f32 = 1.0f;
+    const float beta_f32  = 0.0f;
+
+    const void * alpha = &alpha_f16;
+    const void * beta  = &beta_f16;
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        dst_t = (char *) dst_f16.alloc(ne_dst);
+
+        nbd2 /= sizeof(float) / sizeof(half);
+        nbd3 /= sizeof(float) / sizeof(half);
+    } else {
+        dst_t = (char *) dst_ddf;
+
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        cu_data_type    = CUDA_R_32F;
+
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
+                                   (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
+                            beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
+                            cu_compute_type,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+#ifdef GGML_USE_MUSA
+    GGML_ASSERT(false);
+#else // !GGML_USE_MUSA
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                alpha, (const char *) src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00,  // strideA
+                       (const char *) src1_f16, CUDA_R_16F,   nb11/nb10, nb12/nb10,  // strideB
+                beta,  (      char *)    dst_t, cu_data_type, ne01,       nb2/nb0,   // strideC
+                ne12*ne13,
+                cu_compute_type,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
+        const int ne23 = ne12*ne13;
+
+        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_f16, src1_f16, dst_t,
+                ptrs_src.get(), ptrs_dst.get(),
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
+                src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
+                nbd2, nbd3,
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
+                ne23,
+                cu_compute_type,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    }
+#endif // GGML_USE_MUSA
+#endif
+
+    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
+    }
+}
+
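+// Top-level mul_mat dispatch: picks between the dedicated vector kernels
+// (mul_mat_vec / mul_mat_vec_q), the quantized mul_mat_q kernels, batched cuBLAS for
+// F16 multi-batch cases, and the generic cuBLAS fallback, based on the tensor types
+// and shapes, whether src0 is split across devices, and per-GPU capabilities.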
+static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
+
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_q     = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    bool any_gpus_with_slow_fp16   = false;
+    bool any_gpus_without_fp16_mma = false;
+
+    if (split) {
+        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+        auto & tensor_split = buft_ctx->tensor_split;
+        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+            // skip devices that are not going to do any work:
+            if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+                continue;
+            }
+
+            const int cc              = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+        }
+    } else {
+        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+    }
+
+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
+    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+        // the custom F16 vector kernel can be used over batched cuBLAS GEMM
+        // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+               && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // general KQ + KQV multi-batch without FlashAttention
+        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
+    } else if (use_mul_mat_vec) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_q) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
+    } else if (use_mul_mat_q) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+    } else {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
+    }
+}
+
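+// Maps a row of the gathered contiguous buffers used by mul_mat_id back to its
+// original position: i1 is the expert slot within a token, i2 is the token index.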
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+                                                 const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+                                                 int64_t ne11, int64_t ne10,
+                                                 size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+                                                  const mmid_row_mapping * __restrict__ row_mapping,
+                                                  int64_t ne0,
+                                                  size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
+static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
+
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
+    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
+    ggml_tensor src0_row = *src0;
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row  = *dst;
+
+    char * src0_original = (char *) src0->data;
+    char * src1_original = (char *) src1->data;
+    char * dst_original  = (char *)  dst->data;
+
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+
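+    // Two paths: with a single src1 batch (ne12 == 1) every (token, expert) pair is
+    // dispatched as an individual row multiplication; otherwise the rows routed to
+    // each expert are gathered into a contiguous buffer, multiplied with one call
+    // per expert, and the results scattered back to their original rows.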
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
+        }
+    } else {
+        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+        ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+
+        src1_row.data = src1_contiguous.get();
+        dst_row.data  =  dst_contiguous.get();
+
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
+            int64_t num_src1_rows = 0;
+
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
+            }
+
+            if (num_src1_rows == 0) {
+                continue;
+            }
+
+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ne11, ne10,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
+            src1_row.nb[1] = nb11;
+            src1_row.nb[2] = num_src1_rows*nb11;
+            src1_row.nb[3] = num_src1_rows*nb11;
+
+            dst_row.ne[1] = num_src1_rows;
+            dst_row.nb[1] = nb1;
+            dst_row.nb[2] = num_src1_rows*nb1;
+            dst_row.nb[3] = num_src1_rows*nb1;
+
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        ne0,
+                        nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
+        }
+    }
+}
+
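+// Dispatches a single graph node to its CUDA implementation; returns false for
+// unsupported ops so the caller can report the failure.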
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
+    // why is this here instead of mul_mat?
+    if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
+        ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
+    }
+
+    switch (dst->op) {
+        case GGML_OP_ARGMAX:
+            ggml_cuda_argmax(ctx, dst);
+            break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_cuda_count_equal(ctx, dst);
+            break;
+        case GGML_OP_REPEAT:
+            ggml_cuda_op_repeat(ctx, dst);
+            break;
+        case GGML_OP_REPEAT_BACK:
+            ggml_cuda_op_repeat_back(ctx, dst);
+            break;
+        case GGML_OP_GET_ROWS:
+            ggml_cuda_op_get_rows(ctx, dst);
+            break;
+        case GGML_OP_DUP:
+            ggml_cuda_dup(ctx, dst);
+            break;
+        case GGML_OP_CPY:
+            ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
+            break;
+        case GGML_OP_CONT:
+            ggml_cuda_dup(ctx, dst);
+            break;
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1: // TODO: more efficient implementation
+            ggml_cuda_op_add(ctx, dst);
+            break;
+        case GGML_OP_SUB:
+            ggml_cuda_op_sub(ctx, dst);
+            break;
+        case GGML_OP_ACC:
+            ggml_cuda_op_acc(ctx, dst);
+            break;
+        case GGML_OP_MUL:
+            ggml_cuda_op_mul(ctx, dst);
+            break;
+        case GGML_OP_DIV:
+            ggml_cuda_op_div(ctx, dst);
+            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_NEG:
+                    ggml_cuda_op_neg(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_STEP:
+                    ggml_cuda_op_step(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_GELU:
+                    ggml_cuda_op_gelu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    ggml_cuda_op_silu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    ggml_cuda_op_gelu_quick(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    ggml_cuda_op_tanh(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_RELU:
+                    ggml_cuda_op_relu(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SIGMOID:
+                    ggml_cuda_op_sigmoid(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                    ggml_cuda_op_hardsigmoid(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_HARDSWISH:
+                    ggml_cuda_op_hardswish(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_EXP:
+                    ggml_cuda_op_exp(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NORM:
+            ggml_cuda_op_norm(ctx, dst);
+            break;
+        case GGML_OP_GROUP_NORM:
+            ggml_cuda_op_group_norm(ctx, dst);
+            break;
+        case GGML_OP_CONCAT:
+            ggml_cuda_op_concat(ctx, dst);
+            break;
+        case GGML_OP_UPSCALE:
+            ggml_cuda_op_upscale(ctx, dst);
+            break;
+        case GGML_OP_PAD:
+            ggml_cuda_op_pad(ctx, dst);
+            break;
+        case GGML_OP_UNPAD:
+            ggml_cuda_op_unpad(ctx, dst);
+            break;
+        case GGML_OP_ARANGE:
+            ggml_cuda_op_arange(ctx, dst);
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            ggml_cuda_op_timestep_embedding(ctx, dst);
+            break;
+        case GGML_OP_LEAKY_RELU:
+            ggml_cuda_op_leaky_relu(ctx, dst);
+            break;
+        case GGML_OP_RMS_NORM:
+            ggml_cuda_op_rms_norm(ctx, dst);
+            break;
+        case GGML_OP_MUL_MAT:
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
+                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
+                return false;
+            } else {
+                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
+            }
+            break;
+        case GGML_OP_MUL_MAT_ID:
+            ggml_cuda_mul_mat_id(ctx, dst);
+            break;
+        case GGML_OP_OUT_PROD:
+            ggml_cuda_out_prod(ctx, dst);
+            break;
+        case GGML_OP_SCALE:
+            ggml_cuda_op_scale(ctx, dst);
+            break;
+        case GGML_OP_SQR:
+            ggml_cuda_op_sqr(ctx, dst);
+            break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
+        case GGML_OP_SIN:
+            ggml_cuda_op_sin(ctx, dst);
+            break;
+        case GGML_OP_COS:
+            ggml_cuda_op_cos(ctx, dst);
+            break;
+        case GGML_OP_CLAMP:
+            ggml_cuda_op_clamp(ctx, dst);
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            ggml_cuda_op_diag_mask_inf(ctx, dst);
+            break;
+        case GGML_OP_SOFT_MAX:
+            ggml_cuda_op_soft_max(ctx, dst);
+            break;
+        case GGML_OP_ROPE:
+            ggml_cuda_op_rope(ctx, dst);
+            break;
+        case GGML_OP_IM2COL:
+            ggml_cuda_op_im2col(ctx, dst);
+            break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cuda_op_conv_transpose_1d(ctx,dst);
+            break;
+        case GGML_OP_POOL_2D:
+            ggml_cuda_op_pool2d(ctx, dst);
+            break;
+        case GGML_OP_SUM:
+            ggml_cuda_op_sum(ctx, dst);
+            break;
+        case GGML_OP_SUM_ROWS:
+            ggml_cuda_op_sum_rows(ctx, dst);
+            break;
+        case GGML_OP_ARGSORT:
+            ggml_cuda_op_argsort(ctx, dst);
+            break;
+#if !defined(GGML_DISABLE_FLASH_ATTN)
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
+#endif
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+            ggml_cuda_cross_entropy_loss(ctx, dst);
+            break;
+        case GGML_OP_RWKV_WKV6:
+            ggml_cuda_op_rwkv_wkv6(ctx, dst);
+            break;
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+            ggml_cuda_cross_entropy_loss_back(ctx, dst);
+            break;
+        case GGML_OP_OPT_STEP_ADAMW:
+            ggml_cuda_opt_step_adamw(ctx, dst);
+            break;
+        default:
+            return false;
+    }
+
+    cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess) {
+        GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
+        CUDA_CHECK(err);
+    }
+
+    return true;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend
+
+static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    return cuda_ctx->name.c_str();
+}
+
+static void ggml_backend_cuda_free(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    delete cuda_ctx;
+    delete backend;
+}
+
+static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
+}
+
+static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+    ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
+}
+
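+// Asynchronous tensor copy between two CUDA backends: same-device copies use
+// cudaMemcpyAsync, cross-device copies use cudaMemcpyPeerAsync (unless peer copies
+// are disabled), and an event recorded on the source stream makes the destination
+// stream wait until the copy has finished.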
+static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
+    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
+        return false;
+    }
+
+    if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
+        return false;
+    }
+
+    // device -> device copy
+    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
+    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+
+    ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
+    ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
+
+    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
+#endif
+        return false;
+    }
+
+    if (backend_src != backend_dst) {
+        // copy on src stream
+        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+        } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+            return false;
+#else
+            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
+#endif
+        }
+
+        // record event on src stream after the copy
+        if (!cuda_ctx_src->copy_event) {
+            ggml_cuda_set_device(cuda_ctx_src->device);
+            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+        }
+
+        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+
+        // wait on dst stream for the copy to complete
+        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+    } else {
+        // src and dst are on the same backend
+        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+    }
+    return true;
+}
+
+static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
+
+    GGML_UNUSED(backend);
+}
+
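+// CUDA graph support: per-node properties (data pointers, op, shapes, strides) are
+// recorded on every evaluation and compared against the previous capture to decide
+// whether the captured CUDA graph can be reused or must be rebuilt.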
+#ifdef USE_CUDA_GRAPH
+static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    graph_node_properties->node_address = node->data;
+    graph_node_properties->node_op = node->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        graph_node_properties->ne[i] = node->ne[i];
+        graph_node_properties->nb[i] = node->nb[i];
+    }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+    }
+    memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
+}
+
+static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+    if (node->data != graph_node_properties->node_address &&
+          node->op != GGML_OP_CPY &&
+          node->op != GGML_OP_VIEW) {
+        return false;
+    }
+
+    if (node->op != graph_node_properties->node_op) {
+        return false;
+    }
+
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (node->ne[i] != graph_node_properties->ne[i]) {
+            return false;
+        }
+        if (node->nb[i] != graph_node_properties->nb[i]) {
+            return false;
+        }
+    }
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (node->src[i] &&
+            node->src[i]->data != graph_node_properties->src_address[i] &&
+            node->op != GGML_OP_CPY &&
+            node->op != GGML_OP_VIEW
+        ) {
+            return false;
+        }
+    }
+
+    if (node->op == GGML_OP_SCALE &&
+        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+        return false;
+    }
+
+    return true;
+}
+#endif
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
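+    // When CUDA graphs are usable, the ggml graph is captured into a CUDA graph on
+    // the first evaluation and whenever its structure changes, then replayed on
+    // subsequent calls; only the pointer argument of the copy kernels is patched
+    // between launches. Otherwise the nodes are executed eagerly one by one.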
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need to be updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-GPU for now. TODO: investigate.
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) {
+            cuda_graph_update_required = true;
+        }
+
+        // Check if the graph size has changed
+        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+            cuda_graph_update_required = true;
+            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+        }
+
+        // Loop over nodes in GGML graph to determine if CUDA graph update is required
+        // and store properties to allow this comparison for the next token
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            bool has_matching_properties = true;
+            if (!cuda_graph_update_required) {
+                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+            }
+            if (!has_matching_properties) {
+                cuda_graph_update_required = true;
+            }
+            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+
+        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+
+            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                continue;
+            }
+
+            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_MUL_MAT_ID) {
+                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+            }
+
+            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+                // disable CUDA graphs for batch size > 1 for now.
+                // Changes in batch size or context size can cause changes to the grid size of some kernels.
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+            }
+
+            if (node->op == GGML_OP_CPY) {
+                // store the copy op parameter which changes with each token.
+                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+                // store a pointer to each copy op CUDA kernel to identify it later
+                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                if (!ptr) {
+                    use_cuda_graph = false;
+#ifndef NDEBUG
+                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+                } else {
+                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                    }
+                }
+            }
+
+            if (!use_cuda_graph) {
+                break;
+            }
+        }
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (use_cuda_graph && cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+        }
+
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
+                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+                    }
+                }
+#endif
+
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+
+        if (cuda_graph_update_required) {
+            // Extract nodes from graph
+            // First call with null argument gets number of nodes in graph
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+            // Subsequent call with non-null argument gets nodes
+            cuda_ctx->cuda_graph->nodes.clear();
+            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+            cuda_ctx->cuda_graph->params.clear();
+            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+            if (cuda_ctx->cuda_graph->num_nodes > 0) {
+                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+                // Loop over nodes, and extract kernel parameters from each node
+                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                    cudaGraphNodeType node_type;
+                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                    if (node_type == cudaGraphNodeTypeKernel) {
+                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                        if (stat == cudaErrorInvalidDeviceFunction) {
+                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                            // We don't need to update blas nodes, so clear error and move on.
+                            cudaGetLastError();
+                        } else {
+                            GGML_ASSERT(stat == cudaSuccess);
+                        }
+                    }
+                }
+            }
+        }
+
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
+            int k = 0;
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+                }
+            }
+        }
+
+        // Update graph executable
+        cudaGraphExecUpdateResultInfo result_info;
+        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+#endif
+            // The pre-existing graph exec cannot be updated due to violated constraints
+            // so instead clear error and re-instantiate
+            cudaGetLastError();
+            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+            cuda_ctx->cuda_graph->instance = nullptr;
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        } else {
+            GGML_ASSERT(stat == cudaSuccess);
+        }
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
+}
+
+static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    if (ggml_backend_is_cuda(backend)) {
+        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
+    } else {
+#if 0
+        // untested
+        auto wait_fn = [](void * user_data) {
+            ggml_backend_event_t event = (ggml_backend_event_t)user_data;
+            ggml_backend_event_synchronize(event);
+        };
+
+        CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
+#endif
+        GGML_ABORT("fatal error");
+    }
+}
+
+static const ggml_backend_i ggml_backend_cuda_interface = {
+    /* .get_name                = */ ggml_backend_cuda_get_name,
+    /* .free                    = */ ggml_backend_cuda_free,
+    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,
+    /* .cpy_tensor_async        = */ ggml_backend_cuda_cpy_tensor_async,
+    /* .synchronize             = */ ggml_backend_cuda_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
+    /* .event_record            = */ ggml_backend_cuda_event_record,
+    /* .event_wait              = */ ggml_backend_cuda_event_wait,
+};
+
+static ggml_guid_t ggml_backend_cuda_guid() {
+    static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
+    return &guid;
+}
+
+bool ggml_backend_is_cuda(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
+}
+
+int ggml_backend_cuda_get_device_count() {
+    return ggml_cuda_info().device_count;
+}
+
+void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
+    cudaDeviceProp prop;
+    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
+    snprintf(description, description_size, "%s", prop.name);
+}
+
+void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
+    ggml_cuda_set_device(device);
+
+    CUDA_CHECK(cudaMemGetInfo(free, total));
+}
+
+bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
+    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+        return false;
+    }
+
+#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+    cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+
+        GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+                           size / 1024.0 / 1024.0, cudaGetErrorString(err));
+        return false;
+    }
+    return true;
+#else
+    return false;
+#endif
+}
+
+void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
+    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+        return;
+    }
+
+    cudaError_t err = cudaHostUnregister(buffer);
+    if (err != cudaSuccess) {
+        // clear the error
+        cudaGetLastError();
+    }
+}
+
+
+// backend device
+
+struct ggml_backend_cuda_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+    ggml_cuda_set_device(ctx->device);
+    CUDA_CHECK(cudaMemGetInfo(free, total));
+}
+
+static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_cuda_device_get_name(dev);
+    props->description = ggml_backend_cuda_device_get_description(dev);
+    props->type        = ggml_backend_cuda_device_get_type(dev);
+    ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
+#ifdef GGML_CUDA_NO_PEER_COPY
+    bool events = false;
+#else
+    bool events = true;
+#endif
+
+    props->caps = {
+        /* .async                 = */ true,
+        /* .host_buffer           = */ host_buffer,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ events,
+    };
+}
+
+static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+    return ggml_backend_cuda_init(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+    return ggml_backend_cuda_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return ggml_backend_cuda_host_buffer_type();
+}
+
+// TODO: move these functions here
+static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+
+    // split buffers can only be used with GGML_OP_MUL_MAT
+    if (op->op != GGML_OP_MUL_MAT) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {
+                return false;
+            }
+        }
+    }
+
+    // check if all the sources are allocated on this device
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {
+            ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;
+            if (buft_ctx->device != dev_ctx->device) {
+                return false;
+            }
+        }
+    }
+
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_EXP:
+                    return ggml_is_contiguous(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
+            {
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+                // for small weight matrices the active device can end up without any rows, don't use row split in those cases
+                // this avoids some edge cases (and the performance would not be good anyways)
+                if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
+                    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
+                    int64_t row_low;
+                    int64_t row_high;
+                    get_row_split(&row_low, &row_high, a, buft_ctx->tensor_split, dev_ctx->device);
+                    if (row_low == row_high) {
+                        return false;
+                    }
+                }
+                if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+                    return false;
+                }
+                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
+                    return false;
+                }
+#ifdef GGML_USE_MUSA
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
+                }
+#endif // GGML_USE_MUSA
+                switch (a->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q2_K:
+                    case GGML_TYPE_Q3_K:
+                    case GGML_TYPE_Q4_K:
+                    case GGML_TYPE_Q5_K:
+                    case GGML_TYPE_Q6_K:
+                    case GGML_TYPE_Q8_K:
+                    case GGML_TYPE_IQ1_M:
+                    case GGML_TYPE_IQ1_S:
+                    case GGML_TYPE_IQ2_S:
+                    case GGML_TYPE_IQ2_XS:
+                    case GGML_TYPE_IQ2_XXS:
+                    case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ3_XXS:
+                    case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_IQ4_XS:
+                    case GGML_TYPE_BF16:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
+        case GGML_OP_CPY:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
+                    return true;
+                }
+                return false;
+            } break;
+        case GGML_OP_DUP:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
+            {
+                return true;
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_REPEAT_BACK:
+                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+        case GGML_OP_CONCAT:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                return false;
+            } break;
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+        case GGML_OP_CLAMP:
+            return true;
+        case GGML_OP_CONT:
+            return op->src[0]->type != GGML_TYPE_BF16;
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+            return true;
+        case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
+        case GGML_OP_SUM:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_UNPAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_LEAKY_RELU:
+        case GGML_OP_RWKV_WKV6:
+            return true;
+        case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+            return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
+        case GGML_OP_CROSS_ENTROPY_LOSS:
+        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+        case GGML_OP_OPT_STEP_ADAMW:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
+}
+
+static int64_t get_op_batch_size(const ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_GET_ROWS:
+            return 0;
+        case GGML_OP_MUL_MAT:
+            return op->ne[1];
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_ROPE:
+            return op->ne[2];
+        default:
+            return ggml_nrows(op);
+    }
+}
+
+static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    const int min_batch_size = 32;
+
+    return get_op_batch_size(op) >= min_batch_size;
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
+#ifdef GGML_CUDA_NO_PEER_COPY
+    return nullptr;
+#else
+    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;
+
+    ggml_cuda_set_device(dev_ctx->device);
+
+    cudaEvent_t event;
+    CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+
+    return new ggml_backend_event {
+        /* .device  = */ dev,
+        /* .context = */ event,
+    };
+#endif
+}
+
+static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    GGML_UNUSED(dev);
+
+    CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
+    delete event;
+}
+
+static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    GGML_UNUSED(dev);
+    CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
+}
+
+static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
+    /* .get_name                = */ ggml_backend_cuda_device_get_name,
+    /* .get_description         = */ ggml_backend_cuda_device_get_description,
+    /* .get_memory              = */ ggml_backend_cuda_device_get_memory,
+    /* .get_type                = */ ggml_backend_cuda_device_get_type,
+    /* .get_props               = */ ggml_backend_cuda_device_get_props,
+    /* .init_backend            = */ ggml_backend_cuda_device_init_backend,
+    /* .get_buffer_type         = */ ggml_backend_cuda_device_get_buffer_type,
+    /* .get_host_buffer_type    = */ ggml_backend_cuda_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr    = */ NULL,
+    /* .supports_op             = */ ggml_backend_cuda_device_supports_op,
+    /* .supports_buft           = */ ggml_backend_cuda_device_supports_buft,
+    /* .offload_op              = */ ggml_backend_cuda_device_offload_op,
+    /* .event_new               = */ ggml_backend_cuda_device_event_new,
+    /* .event_free              = */ ggml_backend_cuda_device_event_free,
+    /* .event_synchronize       = */ ggml_backend_cuda_device_event_synchronize,
+};
+
+// backend reg
+
+struct ggml_backend_cuda_reg_context {
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return GGML_CUDA_NAME;
+}
+
+static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) {
+    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
+    return ctx->devices.size();
+}
+
+static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
+}
+
+static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) {
+    static std::vector<ggml_backend_feature> features = []() {
+        std::vector<ggml_backend_feature> features;
+    #define _STRINGIFY(...) #__VA_ARGS__
+    #define STRINGIFY(...) _STRINGIFY(__VA_ARGS__)
+
+    #ifdef __CUDA_ARCH_LIST__
+        features.push_back({ "ARCHS", STRINGIFY(__CUDA_ARCH_LIST__) });
+    #endif
+
+    #ifdef GGML_CUDA_FORCE_MMQ
+        features.push_back({ "FORCE_MMQ", "1" });
+    #endif
+
+    #ifdef GGML_CUDA_FORCE_CUBLAS
+        features.push_back({ "FORCE_CUBLAS", "1" });
+    #endif
+
+    #ifdef GGML_CUDA_NO_VMM
+        features.push_back({ "NO_VMM", "1" });
+    #endif
+
+    #ifdef GGML_CUDA_NO_PEER_COPY
+        features.push_back({ "NO_PEER_COPY", "1" });
+    #endif
+
+    #ifdef GGML_CUDA_F16
+        features.push_back({ "F16", "1" });
+    #endif
+
+    #ifdef GGML_CUDA_USE_GRAPHS
+        features.push_back({ "USE_GRAPHS", "1" });
+    #endif
+
+    #ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
+        features.push_back({ "PEER_MAX_BATCH_SIZE", STRINGIFY(GGML_CUDA_PEER_MAX_BATCH_SIZE) });
+    #endif
+
+    #ifdef GGML_CUDA_FA_ALL_QUANTS
+        features.push_back({ "FA_ALL_QUANTS", "1" });
+    #endif
+
+    #undef _STRINGIFY
+    #undef STRINGIFY
+
+        features.push_back({ nullptr, nullptr });
+
+        return features;
+    }();
+
+    return features.data();
+
+    GGML_UNUSED(reg);
+}
+
+static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    GGML_UNUSED(reg);
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_cuda_split_buffer_type;
+    }
+    if (strcmp(name, "ggml_backend_register_host_buffer") == 0) {
+        return (void *)ggml_backend_cuda_register_host_buffer;
+    }
+    if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) {
+        return (void *)ggml_backend_cuda_unregister_host_buffer;
+    }
+    if (strcmp(name, "ggml_backend_get_features") == 0) {
+        return (void *)ggml_backend_cuda_get_features;
+    }
+    return nullptr;
+}
+
+static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
+    /* .get_name          = */ ggml_backend_cuda_reg_get_name,
+    /* .get_device_count  = */ ggml_backend_cuda_reg_get_device_count,
+    /* .get_device        = */ ggml_backend_cuda_reg_get_device,
+    /* .get_proc_address  = */ ggml_backend_cuda_reg_get_proc_address,
+};
+
+// backend registry
+ggml_backend_reg_t ggml_backend_cuda_reg() {
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
+
+            for (int i = 0; i < ggml_cuda_info().device_count; i++) {
+                ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
+                dev_ctx->device = i;
+                dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
+
+                ggml_cuda_set_device(i);
+                cudaDeviceProp prop;
+                CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
+                dev_ctx->description = prop.name;
+
+                ggml_backend_dev_t dev = new ggml_backend_device {
+                    /* .iface   = */ ggml_backend_cuda_device_interface,
+                    /* .reg     = */ &reg,
+                    /* .context = */ dev_ctx
+                };
+                ctx->devices.push_back(dev);
+            }
+
+            reg = ggml_backend_reg {
+                /* .api_version = */ GGML_BACKEND_API_VERSION,
+                /* .iface       = */ ggml_backend_cuda_reg_interface,
+                /* .context     = */ ctx
+            };
+        }
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
+ggml_backend_t ggml_backend_cuda_init(int device) {
+    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
+        GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device);
+        return nullptr;
+    }
+
+    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
+    if (ctx == nullptr) {
+        GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_t cuda_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_cuda_guid(),
+        /* .interface = */ ggml_backend_cuda_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
+        /* .context   = */ ctx,
+    };
+
+    return cuda_backend;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg)
diff --git a/llama/ggml-cuda/im2col.cu b/llama/ggml-cuda/im2col.cu
new file mode 100644
index 000000000..0ceaa02c9
--- /dev/null
+++ b/llama/ggml-cuda/im2col.cu
@@ -0,0 +1,129 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "im2col.cuh"
+
+template <typename T>
+static  __global__ void im2col_kernel(
+        const float * x, T * dst, int64_t batch_offset,
+        int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i >= pelements) {
+        return;
+    }
+
+    const int64_t  ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t  kx = i / ksize;
+    const int64_t  kd = kx * ksize;
+    const int64_t  ky = (i - kd) / OW;
+    const int64_t  ix = i % OW;
+
+    const int64_t  oh = blockIdx.y;
+    const int64_t  batch = blockIdx.z / IC;
+    const int64_t  ic = blockIdx.z % IC;
+
+    const int64_t iiw = ix * s0 + kx * d0 - p0;
+    const int64_t iih = oh * s1 + ky * d1 - p1;
+
+    const int64_t offset_dst =
+        ((batch * OH + oh) * OW + ix) * CHW +
+        (ic * (KW * KH) + ky * KW + kx);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+        dst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template <typename T>
+static void im2col_cuda(const float * x, T* dst,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+    const int parallel_elements = OW * KW * KH;
+    const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+    dim3 block_nums(num_blocks, OH, batch * IC);
+    im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
+static void im2col_cuda_f16(const float * x, half * dst,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+    im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+static void im2col_cuda_f32(const float * x, float * dst,
+    int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+    int64_t batch, int64_t batch_offset, int64_t offset_delta,
+    int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+    im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
+}
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const float * src1_d = (const float *)src1->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch        = src1->ne[is_2D ? 3 : 2];
+    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+
+    if(dst->type == GGML_TYPE_F16) {
+        im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+    } else {
+        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
+    }
+}
diff --git a/llama/ggml-cuda/im2col.cuh b/llama/ggml-cuda/im2col.cuh
new file mode 100644
index 000000000..2c64c16b1
--- /dev/null
+++ b/llama/ggml-cuda/im2col.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_IM2COL_BLOCK_SIZE 256
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/mma.cuh b/llama/ggml-cuda/mma.cuh
new file mode 100644
index 000000000..557cdcd1e
--- /dev/null
+++ b/llama/ggml-cuda/mma.cuh
@@ -0,0 +1,247 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+struct mma_int_A_I16K4 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 4;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_A_I16K8 {
+    static constexpr int I  = 16;
+    static constexpr int K  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE)
+        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
+        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_i(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
+};
+
+struct mma_int_B_J8K4 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 4;
+    static constexpr int ne = 1;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int /* l */) {
+        const int ret = threadIdx.x % K;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride;
+        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
+            : "+r"(x[0])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_B_J8K8 {
+    static constexpr int J  = 8;
+    static constexpr int K  = 8;
+    static constexpr int ne = 2;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_j(const int /* l */) {
+        const int ret = threadIdx.x / (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_k(const int l) {
+        const int ret = l * (K/2) + threadIdx.x % (K/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  K);
+        return ret;
+    }
+
+    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
+#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
+        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
+        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "+r"(x[0]), "+r"(x[1])
+            : "l"(xs));
+#else
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            x[l] = xs0[get_j(l)*stride + get_k(l)];
+        }
+#endif // defined(INT8_MMA_AVAILABLE)
+    }
+};
+
+struct mma_int_C_I16J8 {
+    static constexpr int I  = 16;
+    static constexpr int J  = 8;
+    static constexpr int ne = 4;
+
+    int x[ne] = {0};
+
+    static __device__ __forceinline__ int get_i(const int l) {
+        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  I);
+        return ret;
+    }
+
+    static __device__ __forceinline__ int get_j(const int l) {
+        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
+        GGML_CUDA_ASSUME(ret >= 0);
+        GGML_CUDA_ASSUME(ret <  J);
+        return ret;
+    }
+
+    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
+#ifdef INT8_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+#else
+        // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[0]), "+r"(x[1])
+            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+        asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+            : "+r"(x[2]), "+r"(x[3])
+            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(mma_A);
+        GGML_UNUSED(mma_B);
+        NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+    }
+};
diff --git a/llama/ggml-cuda/mmq.cu b/llama/ggml-cuda/mmq.cu
new file mode 100644
index 000000000..0dc63b31b
--- /dev/null
+++ b/llama/ggml-cuda/mmq.cu
@@ -0,0 +1,178 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "mmq.cuh"
+
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    const int64_t row_diff = row_high - row_low;
+    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
+
+    int id = ggml_cuda_get_device();
+    const int compute_capability = ggml_cuda_info().devices[id].cc;
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    // The stream-k decomposition is only faster for recent NVIDIA GPUs.
+    // Also its fixup needs to allocate a temporary buffer in the memory pool.
+    // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
+    const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddf_i);
+}
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+    bool mmq_supported;
+
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            mmq_supported = true;
+            break;
+        default:
+            mmq_supported = false;
+            break;
+    }
+
+    if (!mmq_supported) {
+        return false;
+    }
+
+    if (int8_mma_available(cc)) {
+        return true;
+    }
+
+    if (cc < GGML_CUDA_CC_DP4A) {
+        return false;
+    }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+    return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+    if (cc < GGML_CUDA_CC_OFFSET_AMD) {
+        return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    }
+
+    return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+}
diff --git a/llama/ggml-cuda/mmq.cuh b/llama/ggml-cuda/mmq.cuh
new file mode 100644
index 000000000..1da4680a2
--- /dev/null
+++ b/llama/ggml-cuda/mmq.cuh
@@ -0,0 +1,2962 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "common.cuh"
+#include "vecdotq.cuh"
+#include "mma.cuh"
+
+#include <climits>
+#include <cstdint>
+
+#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_NWARPS 8
+
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
+
+enum mmq_q8_1_ds_layout {
+    MMQ_Q8_1_DS_LAYOUT_D4,
+    MMQ_Q8_1_DS_LAYOUT_DS4,
+    MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
+struct block_q8_1_mmq {
+    // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+    // The y float data is first grouped as blocks of 128 values.
+    // These blocks are then treated as individual data values and transposed.
+    //
+    // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+    // This padding is also used to store block scales/partial sums.
+    // The scales multiplied with the quantized data are equal to the unquantized values.
+    // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+    //     and are only needed for performance reasons.
+    //
+    // The exact data stored depends on the x data type.
+    union {
+        float d4[4];    // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+        half2 ds4[4];   // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+        half  d2s6[8];  // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                        //     stored as d0,d1,s1,s2,s3,s4,s5
+    };
+    int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
+};
+static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
+
+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q5_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q5_1:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q8_0:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q2_K:
+            return MMQ_Q8_1_DS_LAYOUT_D2S6;
+        case GGML_TYPE_Q3_K:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_IQ1_S:
+            return MMQ_Q8_1_DS_LAYOUT_DS4;
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+struct tile_x_sizes {
+    int qs;
+    int dm;
+    int sc;
+};
+
+static constexpr int get_mmq_x_max_host(const int cc) {
+    return int8_mma_available(cc) ? 128 :
+#ifdef GGML_CUDA_FORCE_MMQ
+        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128                     : 64;
+#else
+        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#ifdef INT8_MMA_AVAILABLE
+    return 128;
+#else // INT8_MMA_AVAILABLE
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+    return 128;
+#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#ifdef GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
+#else // GGML_CUDA_FORCE_MMQ
+    return 128;
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+
+    return 64;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#endif // INT8_MMA_AVAILABLE
+}
+
+static constexpr int get_mmq_y_host(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+}
+
+static constexpr __device__ int get_mmq_y_device() {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA1)
+    return 64;
+#else
+    return 128;
+#endif // defined RDNA1
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    return 128;
+#else
+    return 64;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+}
+
+#define MMQ_DP4A_TXS_Q4_0    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_0   + mmq_y/QI4_0,     0}
+#define MMQ_DP4A_TXS_Q4_1    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_1   + mmq_y/QI4_1,     0}
+#define MMQ_DP4A_TXS_Q8_0    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE         + mmq_y,           0}
+#define MMQ_DP4A_TXS_Q3_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y,                                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K    tile_x_sizes{mmq_y*WARP_SIZE   + mmq_y, mmq_y*WARP_SIZE/QI4_K,                     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K   + mmq_y/QI5_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K    tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K   + mmq_y/QI6_K,     mmq_y*WARP_SIZE/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
+        type == GGML_TYPE_Q4_1    ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_DP4A_TXS_Q8_1 :
+        type == GGML_TYPE_Q8_0    ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K    ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K    ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ2_XS  ? MMQ_DP4A_TXS_Q8_0_16 :
+        type == GGML_TYPE_IQ2_S   ? MMQ_DP4A_TXS_Q8_0_16 :
+        type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ3_S   ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ1_S   ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q8_0 :
+        tile_x_sizes{0, 0, 0};
+}
+
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0                 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE                         + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2                       + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K     + WARP_SIZE/8 + 7)
+
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q8_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K    ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K    ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q8_1 :
+        type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ2_XS  ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_IQ2_S   ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ3_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ1_S   ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q8_0 :
+        0;
+}
+
+#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+
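+// Tile granularity in the mmq_x direction: the int8 MMA path uses 16 for wide tiles
+// (mmq_x >= 48) and 8 otherwise; the dp4a fallback always uses 8.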
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+}
+
+#ifdef INT8_MMA_AVAILABLE
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+    return mmq_x >= 48 ? 16 : 8;
+}
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
+    return 8;
+}
+#endif // INT8_MMA_AVAILABLE
+
+// ------------------------------------------------------------
+
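+// Each quantization type below provides up to three device functions:
+//   load_tiles_*    - load and repack one tile of the quantized x matrix into shared memory
+//                     (the layout differs between the int8 MMA path and the dp4a fallback),
+//   vec_dot_*_dp4a  - accumulate the tile product against the q8_1 y tile using dp4a,
+//   vec_dot_*_mma   - the same product using int8 tensor core MMA fragments.
+// Each row of the y tile stores 4 ints of q8_1 scales/sums followed by the quantized values,
+// hence the "+ 4" offsets on y_qs and the MMQ_TILE_Y_K row stride.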
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_0;
+    const int kqsx = threadIdx.x % QI4_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+        const int qs0 = get_int_b2(bxi->qs, kqsx);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
+#else
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+                int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+                for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+                }
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u,
+                     x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_1;
+    const int kqsx = threadIdx.x % QI4_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+        const int qs0 = get_int_b4(bxi->qs, kqsx);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
+        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+                int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+                for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+                    u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs +  l];
+                    u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+                }
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+                    (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u,
+                     x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
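+// For q5_0/q5_1 the fifth (high) bit of each value is stored separately in qh; the loaders
+// below splice it back into bit positions 4/12/20/28 of the packed nibbles (q5_0 additionally
+// recenters the values by subtracting 16).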
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI5_0;
+    const int kqsx = threadIdx.x % QI5_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_b2(bxi->qs, kqsx);
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+
+        int qs0 = (ql >>  0)   & 0x0F0F0F0F;
+        qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
+        qs0    |= (qh << 11)   & 0x00001000;  // 1 -> 12
+        qs0    |= (qh << 18)   & 0x00100000;  // 2 -> 20
+        qs0    |= (qh << 25)   & 0x10000000;  // 3 -> 28
+        qs0     = __vsubss4(qs0, 0x10101010); // subtract 16
+
+        int qs1 = (ql >>  4)   & 0x0F0F0F0F;
+        qs1    |= (qh >> 12)   & 0x00000010;  // 16 ->  4
+        qs1    |= (qh >>  5)   & 0x00001000;  // 17 -> 12
+        qs1    |= (qh <<  2)   & 0x00100000;  // 18 -> 20
+        qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
+        qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
+        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI5_1;
+    const int kqsx = threadIdx.x % QI5_1;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+
+        const int ql = get_int_b4(bxi->qs, kqsx);
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+
+        int qs0 = (ql >>  0) & 0x0F0F0F0F;
+        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
+        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
+        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
+        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
+
+        int qs1 = (ql >>  4) & 0x0F0F0F0F;
+        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
+        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
+        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
+        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
+        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
+#else
+        x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_tile + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI8_0;
+    const int kqsx = threadIdx.x % QI8_0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
+        x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) {
+        int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
+#else
+        x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE],
+                     x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]);
+            }
+        }
+    }
+}
+
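+// MMA variant: each warp loads its A (x) fragments and their per-block scales once and reuses
+// them for every column block of y, scaling the int32 MMA results by dA*dB in fp32.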
+template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+    const half2 * y_ds = (const half2 *) y;
+
+    mma_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+                const int k0 = k00 + k01;
+
+                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+            mma_B  B;
+            float dB[mma_C::ne/2];
+
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                } else {
+                    dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+                }
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_0], B);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                }
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl
+                    (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    typedef mma_int_A_I16K8 mma_A;
+    typedef mma_int_B_J8K8  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_dm = (const half2 *) y;
+
+    mma_A    A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+
+    const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+                const int k0 = k00 + k01;
+
+                dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            mma_B    B;
+            float2 dsB[mma_C::ne/2];
+
+            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C;
+                C.mma_K8(A[n][k01/QI8_1], B);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                }
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0],
+                    &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+                    y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    float  dA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
+                const int k0 = k00 + k01;
+
+                dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                }
+            }
+        }
+    }
+#else
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
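+// q2_K stores a 4-bit scale and a 4-bit minimum per 16-value group; the loader packs
+// scale*d and min*dmin into one half2 per group so the dot-product kernels can apply both.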
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI2_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
+
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+
+#pragma unroll
+        for (int l = 0; l < QR2_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int sc_m = bxi->scales[kqsx];
+#ifdef FAST_FP16_AVAILABLE
+        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
+#else
+        const float2 bxi_dmf = __half22float2(bxi->dm);
+        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
+#endif // FAST_FP16_AVAILABLE
+
+#ifdef INT8_MMA_AVAILABLE
+        x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
+#else
+        x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    float2 y_df[mmq_x/nwarps];
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+    }
+
+#pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                if (k01 < WARP_SIZE/2) {
+                    constexpr int ns = 2;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                } else {
+                    constexpr int ns = 1;
+                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+                }
+            }
+        }
+    }
+}
+
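+// MMA variant for q2_K: the scale part (dA) multiplies the regular MMA products, while the
+// minimum part (mA) is subtracted separately, using an all-ones A fragment (0x01010101) to
+// get the per-column sums of the q8_1 values.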
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_A_I16K8 mma_A_K8;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    float  dA[ntx][mma_C::ne/2][8];
+    float  mA[ntx][mma_C::ne/2][8];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            const int k0 = k00 + k01;
+
+            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+        }
+    }
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+#pragma unroll
+            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
+                const int k0 = k00 + k01;
+
+                const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+                dA[n][l][k01/(QI8_1/2)] = dm.x;
+                mA[n][l][k01/(QI8_1/2)] = dm.y;
+            }
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float2 dB[mma_C::ne/2];
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int j = j0 + mma_C::get_j(l);
+
+            dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
+            mma_B B[2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+
+            mma_C Cm[2];
+            if (k01 >= WARP_SIZE * 3/4) {
+                mma_A A1;
+                A1.x[0] = 0x01010101;
+                A1.x[1] = 0x01010101;
+                Cm[0].mma_K4(A1, B[0]);
+                Cm[1].mma_K4(A1, B[1]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C Cd[2];
+
+                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+                    if (k01 >= WARP_SIZE * 3/4) {
+                        tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+                    }
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                }
+            }
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
+            float2 sB[mma_C::ne/2];
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                }
+            }
+        }
+    }
+#else
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
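+// q3_K: 2-bit quants plus a separate high-bit mask (hmask); the loader combines them and
+// subtracts 4 to get signed values, and unpacks the 6-bit group scales (stored with a bias of 32).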
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI3_K;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) {
+        int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);
+        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+        for (int l = 0; l < QR3_K; ++l) {
+            const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+            const int x_ql_k =  (x_ql_0 >> (2*l))       & 0x03030303;
+            const int x_qh_k = ((x_qh_0 >>    l)  << 2) & 0x04040404;
+
+            const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
+#endif // INT8_MMA_AVAILABLE
+        }
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+#ifdef INT8_MMA_AVAILABLE
+        const int8_t * sc8 = (const int8_t *) &sc;
+        const float d = bxi->d;
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
+        }
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#ifndef INT8_MMA_AVAILABLE
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
+        int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+        x_df[i] = bxi->d;
+    }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4;
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+                    x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
+    // scale arrangement after the following two lines:
+    //   - ksc == 0: sc0, sc1, sc2, sc3
+    //   - ksc == 1: sc4, sc5, sc6, sc7
+    //   - ksc == 2:  m0,  m1,  m2,  m3
+    //   - ksc == 3:  m4,  m5,  m6,  m7
+    return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
+           ((scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030);  // upper 2 bits
+}
+
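+// q4_K and q5_K share the same packing of eight 6-bit (scale, min) pairs, unpacked above.
+// The MMA path pre-multiplies them with the super-block (d, dmin) pair, storing d*sc and
+// -dmin*m as a half2 so the minimum enters the accumulated sum with the correct sign.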
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+        const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+        x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+        const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+        const uint8_t * sc8 = (const uint8_t *) &sc32;
+        const uint8_t *  m8 = (const uint8_t *)  &m32;
+
+        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+        }
+    }
+
+#else
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) {
+        int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+        x_dm[i] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+    }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
+                    &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_dm = (half2 *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_dm + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+        const int ky = QR5_K*threadIdx.x;
+
+        const int ql = get_int_b4(bxi->qs, threadIdx.x);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#ifdef INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
+        int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+        const int ksc = threadIdx.x % (WARP_SIZE/16);
+
+        const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+        const int  m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+        const uint8_t * sc8 = (const uint8_t *) &sc32;
+        const uint8_t *  m8 = (const uint8_t *)  &m32;
+
+        const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+        for (int l = 0; l < sizeof(int); ++l) {
+            x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+        }
+    }
+
+#else
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) {
+        int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        x_dm[i] = bxi->dm;
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) {
+        int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+        const int * scales = (const int *) bxi->scales;
+
+        const int ksc = threadIdx.x % (WARP_SIZE/8);
+        const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+        x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
+    }
+#endif // INT8_MMA_AVAILABLE
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_dm + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+                    x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
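+// q6_K: 4 low bits in ql plus 2 high bits in qh per value, recombined and recentered by
+// subtracting 32; group scales are signed 8-bit values, one per 16 quants.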
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+    int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+    int   * x_sc = (int   *) (x_df + txs.dm);
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+
+        const int ql = get_int_b2(bxi->ql, threadIdx.x);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4));
+        const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030;
+        const int qh1 =  (qh >> ((threadIdx.x & 0x08) >> 2))       & 0x30303030;
+
+        const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
+        const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
+#else
+        x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#else
+        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + txs.qs;
+    const int   * x_sc = (const int   *) x_df + txs.dm;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+// #pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]);
+
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
+                    &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+                    x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+            }
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+    const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
+#ifdef INT8_MMA_AVAILABLE
+
+    typedef mma_int_A_I16K4 mma_A;
+    typedef mma_int_B_J8K4  mma_B;
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+
+    const int   * x_qs = (const int   *) x;
+    const float * x_df = (const float *) x_qs + WARP_SIZE*2;
+    const int   * x_sc = (const int   *) x_df + WARP_SIZE/QI6_K;
+    const int   * y_qs = (const int   *) y + 4;
+    const float * y_df = (const float *) y;
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+
+    mma_A   A[ntx][8];
+    int   scA[ntx][mma_C::ne/2][8];
+    float  dA[ntx][mma_C::ne/2];
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            const int k0 = k00 + k01;
+
+            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
+            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+        }
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) {
+            const int k0 = k00 + k01;
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+                const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+                const int8_t * sc        = (const int8_t *) &sc_packed;
+
+#pragma unroll
+                for (int ksc = 0; ksc < sizeof(int); ++ksc) {
+                    scA[n][l][k01/4 + ksc] = sc[ksc];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int l = 0; l < mma_C::ne/2; ++l) {
+            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+
+            dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+        float tmp[ntx][mma_C::ne] = {{0.0f}};
+
+#pragma unroll
+        for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
+            mma_B B[2];
+            float dB[mma_C::ne/2];
+
+            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
+            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+            for (int l = 0; l < mma_C::ne/2; ++l) {
+                const int j = j0 + mma_C::get_j(l);
+
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            }
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                mma_C C[2];
+                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
+                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+                for (int l = 0; l < mma_C::ne; ++l) {
+                    tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
+                }
+            }
+        }
+
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            }
+        }
+    }
+#else
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    NO_DEVICE_CODE;
+#endif // INT8_MMA_AVAILABLE
+}
+
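+// iq4_nl: 4-bit indices into a non-linear lookup table; get_int_from_table_16 expands the
+// eight nibbles of one int into two ints of table values.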
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_NL;
+    const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_XXS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+
+        const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
+        const uint8_t * aux8 = (const uint8_t *) &q2;
+        const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
+
+#pragma unroll
+        for (int l = 0; l < QR2_XXS; ++l) {
+            const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
+            const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+
+            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+
+            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = aux32 >> 28;
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_XS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+
+        const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint16_t * q2 = (const uint16_t *) &q2_packed;
+
+    #pragma unroll
+        for (int l = 0; l < QR2_XS; ++l) {
+            const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
+            const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l] >> 9));
+
+            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = bxi->scales[kqsx];
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#else
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI2_S/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+
+        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+        const int       signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
+        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+        for (int l = 0; l < QR2_S; ++l) {
+            const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
+
+            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+            const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+            const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = bxi->scales[kqsx];
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#else
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
+        x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI3_XXS/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+
+        const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint8_t * q3 = (const uint8_t *) &q3_packed;
+        const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
+
+#pragma unroll
+        for (int l = 0; l < QR3_XXS; ++l) {
+            const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+
+            const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+
+            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = aux32 >> 28;
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % (QI3_S/2);
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) {
+        int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+
+        const int2      qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+        const int       signs_packed_32 = get_int_b2(bxi->signs, kqsx);
+        const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+        for (int l = 0; l < QR3_S; ++l) {
+            const int2 grid_pos = make_int2(
+                iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
+                iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
+
+            const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+            const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
+        const float d = bxi->d;
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kqsx = threadIdx.x % QI1_S;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) {
+        int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+
+        const int       qs_packed = get_int_b2(bxi->qs, kqsx);
+        const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+        const int qh = bxi->qh[kqsx];
+
+    #pragma unroll
+        for (int l = 0; l < QR1_S/2; ++l) {
+            const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
+
+            const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+            const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+#ifdef INT8_MMA_AVAILABLE
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
+#else
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
+            x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
+#endif // INT8_MMA_AVAILABLE
+        }
+
+        const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
+        const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#else
+        x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = 0;           // threadIdx.x / QI4_XS
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+        const float d = __half2float(bxi->d);
+
+        const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+            | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
+    typedef mma_int_C_I16J8 mma_C;
+
+    constexpr int granularity = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+
+    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+#ifdef INT8_MMA_AVAILABLE
+    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+#endif // INT8_MMA_AVAILABLE
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+            for (int l = 0; l < mma_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+
+                if (j > j_max) {
+                    continue;
+                }
+
+                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+
+                if (need_check && i > i_max) {
+                    continue;
+                }
+
+                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+            }
+        }
+    }
+}
+
+// -------------------------------------------------------------------------------------------------------------------------------------
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
+struct mmq_type_traits;
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
+    static constexpr int              vdr          = VDR_Q4_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_0;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
+    static constexpr int              vdr          = VDR_Q4_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_1;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
+    static constexpr int              vdr          = VDR_Q5_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_0;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
+    static constexpr int              vdr          = VDR_Q5_1_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_1;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
+    static constexpr int              vdr          = VDR_Q8_0_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q8_0;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
+    static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q2_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q2_K_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
+    static constexpr int              vdr          = VDR_Q3_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q3_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
+    static constexpr int              vdr          = VDR_Q4_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q4_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
+    static constexpr int              vdr          = VDR_Q5_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q5_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
+    static constexpr int              vdr          = VDR_Q6_K_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_q6_K;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q6_K_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XXS> {
+    static constexpr int              vdr          = VDR_IQ2_XXS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xxs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_XS> {
+    static constexpr int              vdr          = VDR_IQ2_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_xs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ2_S> {
+    static constexpr int              vdr          = VDR_IQ2_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq2_s;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_XXS> {
+    static constexpr int              vdr          = VDR_IQ3_XXS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_xxs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ3_S> {
+    static constexpr int              vdr          = VDR_IQ3_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq3_s;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ1_S> {
+    static constexpr int              vdr          = VDR_IQ1_S_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq1_s;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_1_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+    static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a;
+};
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
+
+    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
+    constexpr int              mmq_y      = get_mmq_y_device();
+    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
+
+    extern __shared__ char data_mul_mat_q[];
+    int * tile_y = (int *) data_mul_mat_q;
+    int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
+
+#ifdef INT8_MMA_AVAILABLE
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
+    constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
+#else
+    constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
+    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
+#endif // INT8_MMA_AVAILABLE
+
+    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+
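+    // Each thread keeps its share of the mmq_x x mmq_y output tile in registers and accumulates into it
+    // across the k loop below.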
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
+
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
+        load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
+
+        {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+        }
+
+        __syncthreads();
+
+        vec_dot(tile_x, tile_y, sum, 0);
+
+        __syncthreads();
+
+        {
+            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+#pragma unroll
+            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
+                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
+
+                tile_y[l] = by0[l];
+            }
+        }
+
+        __syncthreads();
+
+        vec_dot(tile_x, tile_y, sum, WARP_SIZE);
+
+        __syncthreads();
+    }
+
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
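+// In short: the combined (output tile, k block) index space is split into contiguous, roughly equal ranges,
+// one per CUDA block. A block that completes an output tile writes it directly to dst; a block whose range
+// ends in the middle of a tile writes its partial sums to tmp_fixup, and mul_mat_q_stream_k_fixup later adds
+// those partial results to the corresponding dst tile.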
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+                blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t kbc      = (int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x;
+    int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
+
+    kbc      -= (kbc      % blocks_per_ne00) % blocks_per_iter;
+    kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int jt =  kbc /    (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+            it, jt, kb0_start, kb0_stop);
+}
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+    constexpr int     mmq_y           = get_mmq_y_device();
+    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
+    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+    const     int64_t blocks_per_ne00 = ne00 / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = ((blockIdx.y*nty + blockIdx.x)     * block_num_mmq)                           / (gridDim.y*gridDim.x);
+    const int bidx_stop  = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
+
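+    // Re-derive the k-block ranges that the MMQ blocks were assigned and pick up the partial tiles whose
+    // final (incomplete) output tile maps to the tile owned by this fixup block.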
+    int64_t kbc_0;
+    int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        kbc_0 = kbc_stop_0;
+        kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
+
+        const int64_t kbc      = kbc_0      - (kbc_0      % blocks_per_ne00) % blocks_per_iter;
+        const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
+
+        any_fixup = true;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            }
+        }
+    }
+
+    if (!any_fixup) {
+        return;
+    }
+
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
+
+struct mmq_args {
+    const char * x; const char * y; float * dst;
+    int64_t ne00; int64_t ne01; int64_t stride01;
+    int64_t ne10; int64_t ne11; int64_t stride11;
+    int64_t ne0;
+    bool use_stream_k;
+};
+
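+// Dynamic shared memory needed per CUDA block: the quantized x tile (MMA layout if int8 MMA is available,
+// otherwise the dp4a layout) plus the q8_1 y tile, padded to a multiple of the thread block size in ints.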
+template <ggml_type type>
+static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
+    const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+    const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
+}
+
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);
+
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
+
+    const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+    if (!shmem_limit_raised[id]) {
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        shmem_limit_raised[id] = true;
+    }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+
+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    if (!args.use_stream_k) {
+        if (args.ne01 % mmq_y == 0) {
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        } else {
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool(id);
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
+    if (args.ne01 % mmq_y == 0) {
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    } else {
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
+    }
+}
+
+template <ggml_type type>
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+    const int id    = ggml_cuda_get_device();
+    const int nsm   = ggml_cuda_info().devices[id].nsm;
+    const int cc    = ggml_cuda_info().devices[id].cc;
+    const int smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    const int mmq_x_max = get_mmq_x_max_host(cc);
+    const int mmq_y = get_mmq_y_host(cc);
+    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+
+    int mmq_x_best  = 0;
+    int nparts_best = INT_MAX;
+
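+    // Choose the smallest mmq_x (in steps of 8, subject to tile granularity and shared memory limits) that
+    // minimizes the number of work partitions; the search stops early once a single partition is enough.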
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+        if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
+            continue;
+        }
+
+        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves_xy_tiling = ntiles_x*block_num_y;
+        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
+
+        if (nparts < nparts_best) {
+            mmq_x_best  = mmq_x;
+            nparts_best = nparts;
+        }
+    }
+
+    switch (mmq_x_best) {
+        case   8:
+            launch_mul_mat_q<type,   8>(ctx, args, stream);
+            break;
+        case  16:
+            launch_mul_mat_q<type,  16>(ctx, args, stream);
+            break;
+        case  24:
+            launch_mul_mat_q<type,  24>(ctx, args, stream);
+            break;
+        case  32:
+            launch_mul_mat_q<type,  32>(ctx, args, stream);
+            break;
+        case  40:
+            launch_mul_mat_q<type,  40>(ctx, args, stream);
+            break;
+        case  48:
+            launch_mul_mat_q<type,  48>(ctx, args, stream);
+            break;
+        case  56:
+            launch_mul_mat_q<type,  56>(ctx, args, stream);
+            break;
+        case  64:
+            launch_mul_mat_q<type,  64>(ctx, args, stream);
+            break;
+        case  72:
+            launch_mul_mat_q<type,  72>(ctx, args, stream);
+            break;
+        case  80:
+            launch_mul_mat_q<type,  80>(ctx, args, stream);
+            break;
+        case  88:
+            launch_mul_mat_q<type,  88>(ctx, args, stream);
+            break;
+        case  96:
+            launch_mul_mat_q<type,  96>(ctx, args, stream);
+            break;
+        case 104:
+            launch_mul_mat_q<type, 104>(ctx, args, stream);
+            break;
+        case 112:
+            launch_mul_mat_q<type, 112>(ctx, args, stream);
+            break;
+        case 120:
+            launch_mul_mat_q<type, 120>(ctx, args, stream);
+            break;
+        case 128:
+            launch_mul_mat_q<type, 128>(ctx, args, stream);
+            break;
+        default:
+            fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+#define DECL_MMQ_CASE(type)                                                        \
+    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
+
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+
+// -------------------------------------------------------------------------------------------------------------------------
+
+void ggml_cuda_op_mul_mat_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11);
diff --git a/llama/ggml-cuda/mmv.cu b/llama/ggml-cuda/mmv.cu
new file mode 100644
index 000000000..37559c742
--- /dev/null
+++ b/llama/ggml-cuda/mmv.cu
@@ -0,0 +1,287 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "mmv.cuh"
+
+template <typename T, typename type_acc, int block_size>
+static __global__ void mul_mat_vec(
+        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+    const int64_t row     = blockIdx.x;
+    const int64_t channel = blockIdx.z;
+    const int     tid     = threadIdx.x;
+
+    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=  channel               *stride_channel_y;
+    dst +=  channel               *stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+    float * buf_iw = (float *) data_mmv;
+
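+    // buf_iw holds one partial sum per warp; it is only needed for the inter-warp reduction when
+    // block_size > WARP_SIZE.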
+    if (block_size > WARP_SIZE) {
+        if (tid < WARP_SIZE) {
+            buf_iw[tid] = 0.0f;
+        }
+        __syncthreads();
+    }
+
+    float sumf;
+
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 * x2 = (const half2 *) x;
+
+        if (std::is_same<type_acc, float>::value) {
+            sumf = 0.0f;
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                const float2 tmpy = y2[col2];
+                sumf += tmpx.x * tmpy.x;
+                sumf += tmpx.y * tmpy.y;
+            }
+        } else {
+#ifdef FP16_AVAILABLE
+            half2 sumh2 = make_half2(0.0f, 0.0f);
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmp = y2[col2];
+                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            }
+
+            sumf = __low2float(sumh2) + __high2float(sumh2);
+#else
+            NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+        const int * x2 = (const int *) x;
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int    tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        }
+    } else {
+        static_assert(std::is_same<T, void>::value, "unsupported type");
+    }
+
+    sumf = warp_reduce_sum(sumf);
+
+    if (block_size > WARP_SIZE) {
+        buf_iw[tid/WARP_SIZE] = sumf;
+        __syncthreads();
+        if (tid >= WARP_SIZE) {
+            return;
+        }
+        sumf = buf_iw[tid];
+        sumf = warp_reduce_sum(sumf);
+    }
+
+    if (tid != 0) {
+        return;
+    }
+
+    dst[row] = sumf;
+}
+
+template <typename T, typename type_acc>
+static void launch_mul_mat_vec_cuda(
+        const T * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        cudaStream_t stream) {
+    GGML_ASSERT(ncols      % 2 == 0);
+    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    const int64_t channel_ratio = nchannels_y / nchannels_x;
+
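+    // Use the smallest thread block size (a multiple of WARP_SIZE, at most 256) that minimizes the number
+    // of iterations each thread makes over the column pairs.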
+    int64_t block_size_best = WARP_SIZE;
+    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
+    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+        const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+        if (niter < niter_best) {
+            niter_best      = niter;
+            block_size_best = block_size;
+        }
+    }
+
+    const int smem = WARP_SIZE*sizeof(float);
+    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_dims(block_size_best, 1, 1);
+    switch (block_size_best) {
+        case   32: {
+            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case   64: {
+            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case   96: {
+            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  128: {
+            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  160: {
+            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  192: {
+            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  224: {
+            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        case  256: {
+            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T, typename type_acc>
+static void mul_mat_vec_cuda(
+        const T * x, const float * y, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        enum ggml_prec prec, cudaStream_t stream) {
+    switch (prec) {
+        case GGML_PREC_DEFAULT: {
+            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+        case GGML_PREC_F32: {
+            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+        } break;
+    }
+}
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    GGML_ASSERT(src1->ne[1] == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne12 = src1->ne[2];
+    GGML_ASSERT(dst->ne[2] == ne12);
+
+    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_ASSERT(src1->ne[3] == 1);
+    GGML_ASSERT( dst->ne[3] == 1);
+
+    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
+    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
+    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
+    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    GGML_ASSERT(src1_ncols == 1);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+
+    // ggml_cuda_op provides single, contiguous matrices
+    const int64_t stride_row         = ne00;
+    const int64_t nchannels_x        = 1;
+    const int64_t nchannels_y        = 1;
+    const int64_t channel_stride_x   = 0;
+    const int64_t channel_stride_y   = 0;
+    const int64_t channel_stride_dst = 0;
+
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddq_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/llama/ggml-cuda/mmv.cuh b/llama/ggml-cuda/mmv.cuh
new file mode 100644
index 000000000..fcfc8ea46
--- /dev/null
+++ b/llama/ggml-cuda/mmv.cuh
@@ -0,0 +1,38 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
+#define MMV_MAX_ROWS 512
+
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_vec(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/llama/ggml-cuda/mmvq.cu b/llama/ggml-cuda/mmvq.cu
new file mode 100644
index 000000000..19ea9aa96
--- /dev/null
+++ b/llama/ggml-cuda/mmvq.cu
@@ -0,0 +1,451 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "mmvq.cuh"
+#include "vecdotq.cuh"
+
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 :
+        type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 :
+        type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 :
+        type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 :
+        type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 :
+        type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 :
+        type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 :
+        type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 :
+        type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 :
+        type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 :
+        type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
+        type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 :
+        type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 :
+        type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
+        type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 :
+        type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 :
+        type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 :
+        type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 :
+        type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 :
+        nullptr;
+}
+
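+// descriptive note (editorial, not upstream): vdr is the number of quantized values each
+// thread feeds into a single vec_dot call (roughly "values per dot-product ratio" in upstream
+// naming); the kernel below combines it with qi to decide how far each thread strides
+// through the blocks of a row.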
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+    return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_1    ? VDR_Q4_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_0    ? VDR_Q5_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_1    ? VDR_Q5_1_Q8_1_MMVQ :
+        type == GGML_TYPE_Q8_0    ? VDR_Q8_0_Q8_1_MMVQ :
+        type == GGML_TYPE_Q2_K    ? VDR_Q2_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q3_K    ? VDR_Q3_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q4_K    ? VDR_Q4_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q5_K    ? VDR_Q5_K_Q8_1_MMVQ :
+        type == GGML_TYPE_Q6_K    ? VDR_Q6_K_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_XS  ? VDR_IQ2_XS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ2_S   ? VDR_IQ2_S_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ3_S   ? VDR_IQ3_S_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_NL  ? VDR_IQ4_NL_Q8_1_MMVQ :
+        type == GGML_TYPE_IQ4_XS  ? VDR_IQ4_XS_Q8_1_MMVQ :
+        1;
+}
+
+template <ggml_type type, int ncols_y>
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void mul_mat_vec_q(
+    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+    constexpr int nwarps              = 1;
+    constexpr int rows_per_cuda_block = 1;
+#else
+    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
+    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+
+    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int row0 = rows_per_cuda_block*blockIdx.x;
+    const     int blocks_per_row_x = ncols_x / qk;
+    const     int blocks_per_col_y = nrows_y / QK8_1;
+    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
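+
+    // descriptive note (editorial, not upstream): each CUDA block covers rows_per_cuda_block
+    // consecutive rows of x starting at row0; its nwarps*WARP_SIZE threads stride through the
+    // quantized blocks of those rows, consuming vdr quants of the matching q8_1 block of y
+    // per inner iteration.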
+
+// partial sum for each thread
+    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+
+    const block_q8_1 * y = (const block_q8_1 *) vy;
+
+    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
+
+        // x block quant index when casting the quants to int
+        const int kqs = vdr * (tid % (qi/vdr));
+
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+            }
+        }
+    }
+
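+    // descriptive note (editorial, not upstream): cross-warp reduction — warps 1..nwarps-1
+    // stash their partial sums in shared memory, then warp 0 folds them into its own partials
+    // and finishes with a warp-level reduction before writing the result.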
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+            for (int i = 0; i < rows_per_cuda_block; ++i) {
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+            }
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+        for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+            for (int l = 0; l < nwarps-1; ++l) {
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+            }
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+        }
+
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+        }
+    }
+}
+
+template <ggml_type type>
+static void mul_mat_vec_q_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
+    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
+
+    int id = ggml_cuda_get_device();
+
+    int64_t nwarps = 1;
+    int64_t rows_per_cuda_block = 1;
+
+    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
+        switch(ncols_y) {
+            case 1:
+                nwarps = 4;
+                rows_per_cuda_block = 1;
+                break;
+            case 2:
+            case 3:
+            case 4:
+                nwarps = 4;
+                rows_per_cuda_block = 2;
+                break;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                nwarps = 2;
+                rows_per_cuda_block = 2;
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+    }
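+    // descriptive note (editorial, not upstream): one block per rows_per_cuda_block rows,
+    // rounded up to cover all nrows_x rows; each block runs nwarps warps of WARP_SIZE threads.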
+    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    switch (ncols_y) {
+        case 1:
+            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 2:
+            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 3:
+            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 4:
+            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 5:
+            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 6:
+            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 7:
+            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        case 8:
+            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+static void mul_mat_vec_q4_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q4_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q2_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q3_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q4_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q5_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_q6_K_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_xs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq2_s_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq3_xxs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_s_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq1_m_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq4_nl_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq4_xs_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+static void mul_mat_vec_iq3_s_q8_1_cuda(
+    const void * vx, const void * vy, float * dst,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+}
+
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t row_diff = row_high - row_low;
+
+    const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne10 % QK8_1 == 0);
+
+    const int64_t ne0 = dst->ne[0];
+
+    int id = ggml_cuda_get_device();
+
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
+    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+
+    GGML_UNUSED(src1);
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src1_ddf_i);
+    GGML_UNUSED(src1_ncols);
+    GGML_UNUSED(src1_padded_row_size);
+}
diff --git a/llama/ggml-cuda/mmvq.cuh b/llama/ggml-cuda/mmvq.cuh
new file mode 100644
index 000000000..ae18ae315
--- /dev/null
+++ b/llama/ggml-cuda/mmvq.cuh
@@ -0,0 +1,35 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
+void ggml_cuda_op_mul_mat_vec_q(
+    ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+    const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/llama/ggml-cuda/norm.cu b/llama/ggml-cuda/norm.cu
new file mode 100644
index 000000000..6bc05ff70
--- /dev/null
+++ b/llama/ggml-cuda/norm.cu
@@ -0,0 +1,250 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "norm.cuh"
+
+template <int block_size>
+static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float2 mean_var = make_float2(0.f, 0.f);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row*ncols + col];
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
+    }
+
+    // sum up partial sums
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
+    }
+
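+    // descriptive note (editorial, not upstream): one-pass statistics — mean = E[x],
+    // var = E[x^2] - E[x]^2 — then normalize by the inverse standard deviation 1/sqrt(var + eps).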
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+    }
+}
+
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+    // blockIdx.x: num_groups idx
+    // threadIdx.x: block_size idx
+    int start = blockIdx.x * group_size;
+    int end = start + group_size;
+
+    start += threadIdx.x;
+
+    if (end >= ne_elements) {
+        end = ne_elements;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += block_size) {
+        tmp += x[j];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float mean = tmp / group_size;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += block_size) {
+        float xi = x[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    float variance = tmp / group_size;
+    float scale = rsqrtf(variance + eps);
+    for (int j = start; j < end; j += block_size) {
+        dst[j] *= scale;
+    }
+}
+
+template <int block_size>
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[row*ncols + col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
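+    // descriptive note (editorial, not upstream): RMSNorm — scale each element by
+    // 1/sqrt(mean(x^2) + eps); unlike norm_f32 above, no mean is subtracted.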
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = scale * x[row*ncols + col];
+    }
+}
+
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
+}
+
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+    if (group_size < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+    }
+}
+
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
+}
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+}
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    int num_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+}
diff --git a/llama/ggml-cuda/norm.cuh b/llama/ggml-cuda/norm.cuh
new file mode 100644
index 000000000..0902f23ac
--- /dev/null
+++ b/llama/ggml-cuda/norm.cuh
@@ -0,0 +1,33 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/opt-step-adamw.cu b/llama/ggml-cuda/opt-step-adamw.cu
new file mode 100644
index 000000000..4bde5c599
--- /dev/null
+++ b/llama/ggml-cuda/opt-step-adamw.cu
@@ -0,0 +1,104 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-impl.h"
+#include "opt-step-adamw.cuh"
+
+#include <cmath>
+
+static __global__ void opt_step_adamw_f32(
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float alpha  = pars[0];
+    const float beta1  = pars[1];
+    const float beta2  = pars[2];
+    const float eps    = pars[3];
+    const float wd     = pars[4];
+    const float beta1h = pars[5];
+    const float beta2h = pars[6];
+
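+    // descriptive note (editorial, not upstream): AdamW step — m <- beta1*m + (1-beta1)*g,
+    // v <- beta2*v + (1-beta2)*g^2, then x <- x*(1 - alpha*wd) - alpha*(m*beta1h) / (sqrt(v*beta2h) + eps).
+    // beta1h and beta2h arrive precomputed from the host; presumably they fold in the usual
+    // 1/(1 - beta^t) bias-correction factors so the kernel needs no iteration count.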
+    const float gi = g[i];
+    const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    g_m[i] = gmi;
+    g_v[i] = gvi;
+
+    const float mh =       gmi*beta1h;
+    const float vh = sqrtf(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+    float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * src0_grad_m  = dst->src[2];
+    const ggml_tensor * src0_grad_v  = dst->src[3];
+    const ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(src0->type         == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type    == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type  == GGML_TYPE_F32);
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_is_contiguous(adamw_params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+
+    float       * src0_d         = (float       *) src0->data;
+    const float * src0_grad_d    = (const float *) src0_grad->data;
+    float       * src0_grad_m_d  = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d  = (float       *) src0_grad_v->data;
+    const float * adamw_params_d = (const float *) adamw_params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);
+}
diff --git a/llama/ggml-cuda/opt-step-adamw.cuh b/llama/ggml-cuda/opt-step-adamw.cuh
new file mode 100644
index 000000000..b956bf93a
--- /dev/null
+++ b/llama/ggml-cuda/opt-step-adamw.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/out-prod.cu b/llama/ggml-cuda/out-prod.cu
new file mode 100644
index 000000000..fb2cc3838
--- /dev/null
+++ b/llama/ggml-cuda/out-prod.cu
@@ -0,0 +1,77 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "out-prod.cuh"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne01 == ne11);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+
+    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 == src1->ne[2]);
+    GGML_ASSERT(ne3 == src0->ne[3]);
+    GGML_ASSERT(ne3 == src1->ne[3]);
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       *  dst_d = (float       *)  dst->data;
+
+    cudaStream_t   stream = ctx.stream();
+    cublasHandle_t handle = ctx.cublas_handle();
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    GGML_ASSERT(ne2 == 1);
+    GGML_ASSERT(ne3 == 1);
+    CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+    const bool src1_T = ggml_is_transposed(src1);
+    const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+    const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
+    GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
+
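+    // descriptive note (editorial, not upstream): ggml_out_prod computes dst = src0 * src1^T
+    // (summing over the shared ne01/ne11 dimension); with cuBLAS's column-major convention
+    // this maps onto a single SGEMM, transposing src1 unless it is already stored transposed.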
+    CUBLAS_CHECK(
+        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                ne0, ne1, ne01,
+                &alpha, src0_d, ne00,
+                        src1_d, ldb,
+                &beta,  dst_d,  ne0));
+}
diff --git a/llama/ggml-cuda/out-prod.cuh b/llama/ggml-cuda/out-prod.cuh
new file mode 100644
index 000000000..4631cd65a
--- /dev/null
+++ b/llama/ggml-cuda/out-prod.cuh
@@ -0,0 +1,29 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/pad.cu b/llama/ggml-cuda/pad.cu
new file mode 100644
index 000000000..aa61c0ada
--- /dev/null
+++ b/llama/ggml-cuda/pad.cu
@@ -0,0 +1,121 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "pad.cuh"
+
+static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    pad_f32_cuda(src0_d, dst_d,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}
+
+static __global__ void unpad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
+    // blockIdx.z: idx of ne2*ne3, aka ne02*ne03
+    // blockIdx.y: idx of ne1
+    // blockIdx.x: idx of ne0 / BLOCK_SIZE
+    int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst =
+        nidx +
+        blockIdx.y * ne0 +
+        blockIdx.z * ne0 * gridDim.y;
+    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+        int offset_src =
+            nidx +
+            blockIdx.y * ne00 +
+            blockIdx.z * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    }
+}
+
+static void unpad_f32_cuda(const float * x, float * dst,
+    const int ne00, const int ne01, const int ne02, const int ne03,
+    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
+    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2*ne3);
+    unpad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
+}
+
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    unpad_f32_cuda(src0_d, dst_d,
+        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+}
diff --git a/llama/ggml-cuda/pad.cuh b/llama/ggml-cuda/pad.cuh
new file mode 100644
index 000000000..9c23680dc
--- /dev/null
+++ b/llama/ggml-cuda/pad.cuh
@@ -0,0 +1,32 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_PAD_BLOCK_SIZE 256
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/pool2d.cu b/llama/ggml-cuda/pool2d.cu
new file mode 100644
index 000000000..adbf1b551
--- /dev/null
+++ b/llama/ggml-cuda/pool2d.cu
@@ -0,0 +1,120 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "pool2d.cuh"
+
+template <typename Ti, typename To>
+static  __global__ void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op) {
+    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+    if (idx >= parallel_elements) {
+        return;
+    }
+
+    const int I_HW = ih * iw;
+    const int O_HW = oh * ow;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / ow;
+    const int cur_ow = idx % O_HW % ow;
+    const Ti* i_ptr = src + nc * I_HW;
+    To* o_ptr = dst + nc * O_HW;
+    const int start_h = cur_oh * sh - ph;
+    const int bh = max(0, start_h);
+    const int eh = min(ih, start_h + kh);
+    const int start_w = cur_ow * sw - pw;
+    const int bw = max(0, start_w);
+    const int ew = min(iw, start_w + kw);
+    const To scale = 1. / (kh * kw);
+    To res = 0;
+
+    switch (op) {
+        case GGML_OP_POOL_AVG: res = 0; break;
+        case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+        default: assert(false);
+    }
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+            Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+            Ti cur = i_ptr[i * iw + j];
+#endif
+            switch (op) {
+                case GGML_OP_POOL_AVG: res += cur * scale; break;
+                case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+                default: assert(false);
+            }
+        }
+    }
+    o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
+static void pool2d_nchw_kernel_f32_f32_cuda(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const float * src, float * dst, const enum ggml_op_pool op,
+        cudaStream_t stream) {
+
+    const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+    dim3 block_nums(num_blocks);
+    pool2d_nchw_kernel<float, float><<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
+}
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
+
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
+
+    const int parallel_elements = N * OC * OH * OW;
+
+    pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
+}
diff --git a/llama/ggml-cuda/pool2d.cuh b/llama/ggml-cuda/pool2d.cuh
new file mode 100644
index 000000000..9c0045f87
--- /dev/null
+++ b/llama/ggml-cuda/pool2d.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_POOL2D_BLOCK_SIZE 256
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/quantize.cu b/llama/ggml-cuda/quantize.cu
new file mode 100644
index 000000000..60341bee8
--- /dev/null
+++ b/llama/ggml-cuda/quantize.cu
@@ -0,0 +1,195 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "quantize.cuh"
+#include <cstdint>
+
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
+    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const int64_t ix1 = blockIdx.y;
+
+    const int64_t i_padded = ix1*kx0_padded + ix0;
+
+    block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int64_t ib = i_padded / QK8_1; // block index
+    const int64_t iqs = i_padded % QK8_1; // quant index
+
+    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
+    float amax = fabsf(xi);
+    float sum = xi;
+
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
+
+    const float d = amax / 127;
+    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+    y[ib].qs[iqs] = q;
+
+    if (iqs > 0) {
+        return;
+    }
+
+    reinterpret_cast<half&>(y[ib].ds.x) = d;
+    reinterpret_cast<half&>(y[ib].ds.y) = sum;
+}
+
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1(
+    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+
+    if (ix0 >= kx0_padded) {
+        return;
+    }
+
+    const float4 * x4 = (const float4 *) x;
+
+    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+
+    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange the calculated sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
+        }
+    }
+
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32]  = d;
+    }
+}
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+
+    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+
+    GGML_UNUSED(type_x);
+}
+
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
+    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+
+    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, kx1, channels);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
diff --git a/llama/ggml-cuda/quantize.cuh b/llama/ggml-cuda/quantize.cuh
new file mode 100644
index 000000000..ee8e2a52b
--- /dev/null
+++ b/llama/ggml-cuda/quantize.cuh
@@ -0,0 +1,50 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "common.cuh"
+#include "mmq.cuh"
+
+#include <cstdint>
+
+#define CUDA_QUANTIZE_BLOCK_SIZE     256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
+
+typedef void (*quantize_cuda_t)(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
+void quantize_row_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
+
+void quantize_mmq_q8_1_cuda(
+    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
+    const ggml_type type_x, cudaStream_t stream);
diff --git a/llama/ggml-cuda/rope.cu b/llama/ggml-cuda/rope.cu
new file mode 100644
index 000000000..fc9f6f2f4
--- /dev/null
+++ b/llama/ggml-cuda/rope.cu
@@ -0,0 +1,524 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "rope.cuh"
+
+struct rope_corr_dims {
+    float v[2];
+};
+
+
+struct mrope_sections {
+    int v[4];
+};
+
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static __device__ void rope_yarn(
+    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
+    float * cos_theta, float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+    *cos_theta = cosf(theta) * mscale;
+    *sin_theta = sinf(theta) * mscale;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_norm(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + 1];
+
+    dst[i + 0] = x0*cos_theta - x1*sin_theta;
+    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_neox(
+    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_multi(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i0 >= n_dims) {
+        const int i = row*ne0 + i0;
+
+        dst[i + 0] = x[i + 0];
+        dst[i + 1] = x[i + 1];
+
+        return;
+    }
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows;
+
+    int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    int sec_w = sections.v[1] + sections.v[0];
+    int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+        theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+    }
+    else if (sector >= sec_w + sections.v[2]) {
+        theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims/2];
+
+    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T, bool has_ff>
+static __global__ void rope_vision(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (i0 >= ne0) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+    const int i  = row*ne0 + i0/2;
+    const int i2 = row/p_delta_rows; // i2-th tokens
+
+    int sect_dims = sections.v[0] + sections.v[1];
+    int sec_w = sections.v[1] + sections.v[0];
+    int sector = (i0 / 2) % sect_dims;
+
+    float theta_base = 0.0;
+    if (sector < sections.v[0]) {
+        const int p = sector;
+        theta_base = pos[i2]*powf(theta_scale, p);
+    }
+    else if (sector >= sections.v[0] && sector < sec_w) {
+        const int p = sector - sections.v[0];
+        theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+    }
+
+    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+    float cos_theta;
+    float sin_theta;
+
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + n_dims];
+
+    dst[i + 0]      = x0*cos_theta - x1*sin_theta;
+    dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+}
+
+template<typename T>
+static void rope_norm_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_norm<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    } else {
+        rope_norm<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    }
+}
+
+template<typename T>
+static void rope_neox_cuda(
+    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors
+                );
+    }
+}
+
+template<typename T>
+static void rope_multi_cuda(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_multi<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    } else {
+        rope_multi<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    }
+}
+
+template<typename T>
+static void rope_vision_cuda(
+    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+    GGML_ASSERT(ne0 % 2 == 0);
+    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nr, n_blocks_x, 1);
+    // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
+    // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    if (freq_factors == nullptr) {
+        rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    } else {
+        rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
+                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+                theta_scale, freq_factors, sections
+                );
+    }
+}
+
+static void rope_norm_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_norm_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f16(
+    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+
+    rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_neox_cuda_f32(
+    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
+) {
+
+    rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+}
+
+static void rope_multi_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_multi_cuda_f32(
+    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f16(
+    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+static void rope_vision_cuda_f32(
+    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
+    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
+) {
+
+    rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
+}
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    const int64_t ne00 = src0->ne[0]; // head dims
+    const int64_t ne01 = src0->ne[1]; // num heads
+    const int64_t ne02 = src0->ne[2]; // num heads
+    const int64_t nr = ggml_nrows(src0);
+
+    //const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int mode       = ((int32_t *) dst->op_params)[2];
+    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+    mrope_sections sections;
+
+    // RoPE alteration for extended context
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
+    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (is_mrope) {
+        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne00/2);
+    }
+
+    const int32_t * pos = (const int32_t *) src1_d;
+
+    const float * freq_factors = nullptr;
+    if (src2 != nullptr) {
+        freq_factors = (const float *) src2->data;
+    }
+
+    rope_corr_dims corr_dims;
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+    // compute
+    if (is_neox) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    } else if (is_mrope && !is_vision) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_multi_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_multi_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    } else if (is_vision) {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_vision_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_vision_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, sections, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    } else {
+        if (src0->type == GGML_TYPE_F32) {
+            rope_norm_cuda_f32(
+                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_norm_cuda_f16(
+                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
+                attn_factor, corr_dims, freq_factors, stream
+            );
+        } else {
+            GGML_ABORT("fatal error");
+        }
+    }
+}
diff --git a/llama/ggml-cuda/rope.cuh b/llama/ggml-cuda/rope.cuh
new file mode 100644
index 000000000..cd5140ce0
--- /dev/null
+++ b/llama/ggml-cuda/rope.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_ROPE_BLOCK_SIZE 256
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/scale.cu b/llama/ggml-cuda/scale.cu
new file mode 100644
index 000000000..b3b38cdf3
--- /dev/null
+++ b/llama/ggml-cuda/scale.cu
@@ -0,0 +1,57 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "scale.cuh"
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = scale * x[i];
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+}
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+}
diff --git a/llama/ggml-cuda/scale.cuh b/llama/ggml-cuda/scale.cuh
new file mode 100644
index 000000000..ae2ec5af5
--- /dev/null
+++ b/llama/ggml-cuda/scale.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_SCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/softmax.cu b/llama/ggml-cuda/softmax.cu
new file mode 100644
index 000000000..52aad62f6
--- /dev/null
+++ b/llama/ggml-cuda/softmax.cu
@@ -0,0 +1,232 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "softmax.cuh"
+
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+    return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32(half val) {
+    return __half2float(val);
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+
+    extern __shared__ float data_soft_max_f32[];
+    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+    // shared memory buffer to cache values between iterations:
+    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
+
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+
+        vals[col] = val;
+        max_val = max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf_iw[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    float tmp = 0.0f; // partial sum
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = expf(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __syncthreads();
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf_iw[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int64_t idst = (int64_t)rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
+    }
+}
+
+template<typename T>
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
+    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 64:
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 128:
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 256:
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 512:
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 1024:
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 2048:
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            case 4096:
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+            default:
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                break;
+        }
+    } else {
+        const size_t shmem_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+    }
+}
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    const float * src0_d = (const float *)src0->data;
+    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+    const int64_t ne00    = src0->ne[0];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+    if (use_f16) {
+        const half * src1_dd = (const half *)src1_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    } else {
+        const float * src1_dd = (const float *)src1_d;
+
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+    }
+}
diff --git a/llama/ggml-cuda/softmax.cuh b/llama/ggml-cuda/softmax.cuh
new file mode 100644
index 000000000..85459e24e
--- /dev/null
+++ b/llama/ggml-cuda/softmax.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/sum.cu b/llama/ggml-cuda/sum.cu
new file mode 100644
index 000000000..e1f0b86e8
--- /dev/null
+++ b/llama/ggml-cuda/sum.cu
@@ -0,0 +1,71 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
+#include "sumrows.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
+#ifdef USE_CUB
+    size_t tmp_size = 0;
+    DeviceReduce::Sum(nullptr,       tmp_size, x, dst, ne, stream);
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+    DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
+#else
+    // Use (inefficient) sum_rows implementation as a fallback.
+    // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
+    sum_rows_f32_cuda(x, dst, ne, 1, stream);
+    GGML_UNUSED(pool);
+#endif // USE_CUB
+}
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+
+    const int64_t ne = ggml_nelements(src0);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t stream = ctx.stream();
+
+    sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
+}
diff --git a/llama/ggml-cuda/sum.cuh b/llama/ggml-cuda/sum.cuh
new file mode 100644
index 000000000..6883be872
--- /dev/null
+++ b/llama/ggml-cuda/sum.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/sumrows.cu b/llama/ggml-cuda/sumrows.cu
new file mode 100644
index 000000000..fbd3cd874
--- /dev/null
+++ b/llama/ggml-cuda/sumrows.cu
@@ -0,0 +1,65 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "sumrows.cuh"
+
+static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col == 0) {
+        dst[row] = sum;
+    }
+}
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+}
diff --git a/llama/ggml-cuda/sumrows.cuh b/llama/ggml-cuda/sumrows.cuh
new file mode 100644
index 000000000..204384f51
--- /dev/null
+++ b/llama/ggml-cuda/sumrows.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
new file mode 100644
index 000000000..48cdc8f40
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
new file mode 100644
index 000000000..6aeab0bac
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
new file mode 100644
index 000000000..2d98ef1a9
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
new file mode 100644
index 000000000..7fe280e05
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
new file mode 100644
index 000000000..9835cbfaa
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
new file mode 100644
index 000000000..45ffa2a8f
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
new file mode 100644
index 000000000..592287a86
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
new file mode 100644
index 000000000..fe080a734
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu
new file mode 100644
index 000000000..0580444e6
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu
new file mode 100644
index 000000000..5b2650d8a
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu
new file mode 100644
index 000000000..886ba3956
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu
new file mode 100644
index 000000000..789757a80
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu
new file mode 100644
index 000000000..a4bfe23ff
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu
new file mode 100644
index 000000000..eab22f0d9
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu
new file mode 100644
index 000000000..3301160fb
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu
new file mode 100644
index 000000000..aa37c412e
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu
new file mode 100644
index 000000000..a2dd8d866
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu
new file mode 100644
index 000000000..709c2de08
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu
new file mode 100644
index 000000000..3279dad9c
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu
new file mode 100644
index 000000000..4e112e133
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu
new file mode 100644
index 000000000..8662359bd
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu
new file mode 100644
index 000000000..bc3c70614
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu
new file mode 100644
index 000000000..027c6d944
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu
new file mode 100644
index 000000000..543346290
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu
new file mode 100644
index 000000000..9cdcd1b39
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu
new file mode 100644
index 000000000..258e08b29
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu
new file mode 100644
index 000000000..7c41007a3
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu
new file mode 100644
index 000000000..0296737fe
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu
new file mode 100644
index 000000000..f9fdc1976
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu
new file mode 100644
index 000000000..518c67255
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu
new file mode 100644
index 000000000..dfb36938d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu
new file mode 100644
index 000000000..4ae015114
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
new file mode 100644
index 000000000..a69a7acb5
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
new file mode 100644
index 000000000..a46aab8a8
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
new file mode 100644
index 000000000..3fe4f970a
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
new file mode 100644
index 000000000..933a5dd7a
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
new file mode 100644
index 000000000..b051c7d1a
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
new file mode 100644
index 000000000..3a90aba7e
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
new file mode 100644
index 000000000..3ddad858a
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
new file mode 100644
index 000000000..df3ce0a3c
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
new file mode 100644
index 000000000..49d2666a4
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
new file mode 100644
index 000000000..531c87c22
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
new file mode 100644
index 000000000..e747f6e7d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
new file mode 100644
index 000000000..d6097d1cb
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
new file mode 100644
index 000000000..a6bda11fd
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
new file mode 100644
index 000000000..800ea14f6
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
new file mode 100644
index 000000000..b3bad6b01
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
new file mode 100644
index 000000000..6a7127ddf
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
new file mode 100644
index 000000000..62351c23b
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
new file mode 100644
index 000000000..1b35f1688
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
new file mode 100644
index 000000000..5c6256810
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
new file mode 100644
index 000000000..6f70b7404
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
new file mode 100644
index 000000000..d91c6f920
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
new file mode 100644
index 000000000..d206889db
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
new file mode 100644
index 000000000..ae104a61e
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
new file mode 100644
index 000000000..ab2c66bec
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
new file mode 100644
index 000000000..4b55d39f3
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
new file mode 100644
index 000000000..1c1065ff3
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
new file mode 100644
index 000000000..b973d1617
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
new file mode 100644
index 000000000..9b3999e89
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
new file mode 100644
index 000000000..fc7fde303
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
new file mode 100644
index 000000000..b1f482722
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
new file mode 100644
index 000000000..b854659a3
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
new file mode 100644
index 000000000..35db0d6d6
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
new file mode 100644
index 000000000..cc76b0fb9
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
new file mode 100644
index 000000000..ff9e76dd6
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
new file mode 100644
index 000000000..4b031d981
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
new file mode 100644
index 000000000..b99bab1e7
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
new file mode 100644
index 000000000..22e2e6db2
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
new file mode 100644
index 000000000..95c1984e5
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
new file mode 100644
index 000000000..65307d393
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
new file mode 100644
index 000000000..ae0ec146b
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
new file mode 100644
index 000000000..1f420c1d9
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
new file mode 100644
index 000000000..1d445af3d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
new file mode 100644
index 000000000..b3a951dc7
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
new file mode 100644
index 000000000..804c30b20
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
new file mode 100644
index 000000000..432928a20
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
new file mode 100644
index 000000000..409f81b06
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
new file mode 100644
index 000000000..032dab7ff
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
new file mode 100644
index 000000000..00014a4f3
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
new file mode 100644
index 000000000..324572636
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
new file mode 100644
index 000000000..e7d49c270
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
new file mode 100644
index 000000000..8d732548b
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
new file mode 100644
index 000000000..a8e257641
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
new file mode 100644
index 000000000..dabbcd237
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
new file mode 100644
index 000000000..cfbae911d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
new file mode 100644
index 000000000..b1bdc1e95
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu
@@ -0,0 +1,36 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 16, float);
+DECL_FATTN_WMMA_F16_CASE(80, 16, float);
+DECL_FATTN_WMMA_F16_CASE(96, 16, float);
+DECL_FATTN_WMMA_F16_CASE(112, 16, float);
+DECL_FATTN_WMMA_F16_CASE(128, 16, float);
+DECL_FATTN_WMMA_F16_CASE(256, 16, float);
diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
new file mode 100644
index 000000000..3151d9d67
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu
@@ -0,0 +1,35 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 32, float);
+DECL_FATTN_WMMA_F16_CASE(80, 32, float);
+DECL_FATTN_WMMA_F16_CASE(96, 32, float);
+DECL_FATTN_WMMA_F16_CASE(112, 32, float);
+DECL_FATTN_WMMA_F16_CASE(128, 32, float);
diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
new file mode 100644
index 000000000..eea23df92
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu
@@ -0,0 +1,36 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 16, half);
+DECL_FATTN_WMMA_F16_CASE(80, 16, half);
+DECL_FATTN_WMMA_F16_CASE(96, 16, half);
+DECL_FATTN_WMMA_F16_CASE(112, 16, half);
+DECL_FATTN_WMMA_F16_CASE(128, 16, half);
+DECL_FATTN_WMMA_F16_CASE(256, 16, half);
diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
new file mode 100644
index 000000000..70ba3a53b
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu
@@ -0,0 +1,36 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 32, half);
+DECL_FATTN_WMMA_F16_CASE(80, 32, half);
+DECL_FATTN_WMMA_F16_CASE(96, 32, half);
+DECL_FATTN_WMMA_F16_CASE(112, 32, half);
+DECL_FATTN_WMMA_F16_CASE(128, 32, half);
+DECL_FATTN_WMMA_F16_CASE(256, 32, half);
diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
new file mode 100644
index 000000000..3a8261ab0
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu
@@ -0,0 +1,34 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-wmma-f16.cuh"
+
+DECL_FATTN_WMMA_F16_CASE(64, 8, half);
+DECL_FATTN_WMMA_F16_CASE(96, 8, half);
+DECL_FATTN_WMMA_F16_CASE(128, 8, half);
+DECL_FATTN_WMMA_F16_CASE(256, 8, half);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
new file mode 100644
index 000000000..f39436687
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
new file mode 100644
index 000000000..086ab539f
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
new file mode 100644
index 000000000..6af7aa320
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
new file mode 100644
index 000000000..fc771442a
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
new file mode 100644
index 000000000..5ba22c06d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
new file mode 100644
index 000000000..647be438b
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
new file mode 100644
index 000000000..b8263fa36
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
new file mode 100644
index 000000000..41986b9d6
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu
new file mode 100644
index 000000000..023aec760
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q2_K);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu
new file mode 100644
index 000000000..f8bba904d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q3_K);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu
new file mode 100644
index 000000000..425d7a613
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_0);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu
new file mode 100644
index 000000000..91bafb73f
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_1);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu
new file mode 100644
index 000000000..a0ad396c1
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_K);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu
new file mode 100644
index 000000000..dc1cbd434
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_0);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu
new file mode 100644
index 000000000..cc70a445c
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_1);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu
new file mode 100644
index 000000000..3ff67b9f1
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_K);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu
new file mode 100644
index 000000000..1d1ffee9f
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q6_K);
diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu
new file mode 100644
index 000000000..1a7e0865d
--- /dev/null
+++ b/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q8_0);
diff --git a/llama/ggml-cuda/tsembd.cu b/llama/ggml-cuda/tsembd.cu
new file mode 100644
index 000000000..c60367838
--- /dev/null
+++ b/llama/ggml-cuda/tsembd.cu
@@ -0,0 +1,73 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "tsembd.cuh"
+
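+// Editorial note (not part of upstream llama.cpp): the kernel below writes a
+// sinusoidal timestep embedding per row: freq = exp(-ln(max_period) * j / half),
+// embed[j] = cos(t * freq), embed[j + half] = sin(t * freq); an odd dim gets a
+// trailing zero element.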
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIDx.y: idx of timesteps->ne[0]
+    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+    int i = blockIdx.y;
+    int j = threadIdx.x + blockIdx.x * blockDim.x;
+    float * embed_data = (float *)((char *)dst +  i*nb1);
+
+    if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
+        embed_data[dim] = 0.f;
+    }
+
+    int half = dim / 2;
+    if (j >= half) {
+        return;
+    }
+
+    float timestep = timesteps[i];
+    float freq = (float)expf(-logf(max_period) * j / half);
+    float arg = timestep * freq;
+    embed_data[j] = cosf(arg);
+    embed_data[j + half] = sinf(arg);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+                                        const int dim, const int max_period, cudaStream_t stream) {
+    int half_ceil = (dim + 1) / 2;
+    int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne00, 1);
+    timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
+}
+
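+// Editorial note: dim and max_period come from dst->op_params, the number of
+// timesteps from src0->ne[0], and the output row stride (in bytes) from dst->nb[1].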
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+
+    timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+}
diff --git a/llama/ggml-cuda/tsembd.cuh b/llama/ggml-cuda/tsembd.cuh
new file mode 100644
index 000000000..629586504
--- /dev/null
+++ b/llama/ggml-cuda/tsembd.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
+
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/unary.cu b/llama/ggml-cuda/unary.cu
new file mode 100644
index 000000000..e20cba020
--- /dev/null
+++ b/llama/ggml-cuda/unary.cu
@@ -0,0 +1,482 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "unary.cuh"
+
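+// Editorial note (not part of upstream llama.cpp): each kernel below applies an
+// elementwise unary op to a contiguous F32 tensor, one thread per element; the
+// *_cuda launchers round the grid size up to ceil(k / BLOCK_SIZE).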
+static __global__ void neg_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = -x[i];
+}
+
+static __global__ void step_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] > 0.0f;
+}
+
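+// Editorial note: this is the tanh-approximation GELU,
+// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).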
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const float GELU_COEF_A    = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
+static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void silu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] / (1.0f + expf(-x[i]));
+}
+
+static __global__ void tanh_f32(const float * x, float * dst, int k) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = tanhf(x[i]);
+}
+
+static __global__ void relu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0);
+}
+
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
+static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+static __global__ void exp_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = expf(x[i]);
+}
+
+static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
+    const int i  = blockDim.x*blockIdx.x + threadIdx.x;
+    if (i >= k) {
+        return;
+    }
+    dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
+static __global__ void sqr_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sqrtf(x[i]);
+}
+
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sinf(x[i]);
+}
+
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = cosf(x[i]);
+}
+
+static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+    neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE;
+    step_f32<<<num_blocks, CUDA_STEP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+    tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
+    hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
+    hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE;
+    exp_f32<<<num_blocks, CUDA_EXP_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+    leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
+static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
+    sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
+    sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
+    cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
+}
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
diff --git a/llama/ggml-cuda/unary.cuh b/llama/ggml-cuda/unary.cuh
new file mode 100644
index 000000000..3a9161bf9
--- /dev/null
+++ b/llama/ggml-cuda/unary.cuh
@@ -0,0 +1,74 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
+#define CUDA_GELU_BLOCK_SIZE 256
+#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_EXP_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_SIN_BLOCK_SIZE 256
+#define CUDA_COS_BLOCK_SIZE 256
+
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/upscale.cu b/llama/ggml-cuda/upscale.cu
new file mode 100644
index 000000000..19c8f2a18
--- /dev/null
+++ b/llama/ggml-cuda/upscale.cu
@@ -0,0 +1,77 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "upscale.cuh"
+
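+// Editorial note (not part of upstream llama.cpp): nearest-neighbour upscaling.
+// The flat output index is decomposed into (i10, i11, i12, i13) and each
+// coordinate is divided by its scale factor to find the source element,
+// which is addressed via the byte strides nb00..nb03.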
+static __global__ void upscale_f32(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        const float sf0, const float sf1, const float sf2, const float sf3) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    if (index >= ne10 * ne11 * ne12 * ne13) {
+        return;
+    }
+
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+    int i00 = i10 / sf0;
+    int i01 = i11 / sf1;
+    int i02 = i12 / sf2;
+    int i03 = i13 / sf3;
+
+    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst,
+        const int nb00, const int nb01, const int nb02, const int nb03,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        const float sf0, const float sf1, const float sf2, const float sf3,
+        cudaStream_t stream) {
+    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+    upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
+}
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const float sf0 = (float)dst->ne[0]/src0->ne[0];
+    const float sf1 = (float)dst->ne[1]/src0->ne[1];
+    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const float sf3 = (float)dst->ne[3]/src0->ne[3];
+
+    upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+}
diff --git a/llama/ggml-cuda/upscale.cuh b/llama/ggml-cuda/upscale.cuh
new file mode 100644
index 000000000..d8bb2ec8d
--- /dev/null
+++ b/llama/ggml-cuda/upscale.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-cuda/vecdotq.cuh b/llama/ggml-cuda/vecdotq.cuh
new file mode 100644
index 000000000..43719cbd7
--- /dev/null
+++ b/llama/ggml-cuda/vecdotq.cuh
@@ -0,0 +1,1159 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include <cstdint>
+
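+// Editorial note: these helpers load a 32-bit word of quantized data assuming
+// only 2-byte (get_int_b2) or 4-byte (get_int_b4) alignment of the source.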
+static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+    int x32  = x16[2*i32 + 0] <<  0;
+    x32     |= x16[2*i32 + 1] << 16;
+
+    return x32;
+}
+
+static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
+    return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ  4
+
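+// Editorial note (not part of upstream llama.cpp): each int in v packs eight
+// 4-bit Q4_0 quants; the low and high nibbles are dotted against the Q8_1
+// values with dp4a, and the final term subtracts the implicit offset of 8
+// from every quant.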
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+    const int * v, const int * u, const float & d4, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+    }
+
+    const float2 ds8f = __half22float2(ds8);
+
+    // second part effectively subtracts 8 from each quant value
+    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+        const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+        // SIMD dot product of quantized values
+        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+    }
+
+#ifdef GGML_CUDA_F16
+    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+    const float d4d8 = tmp.x;
+    const float m4s8 = tmp.y;
+#else
+    const float2 dm4f = __half22float2(dm4);
+    const float2 ds8f = __half22float2(ds8);
+    const float d4d8 = dm4f.x * ds8f.x;
+    const float m4s8 = dm4f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+    const float2 ds8f = __half22float2(ds8);
+
+    // second part effectively subtracts 16 from each quant value
+    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ  4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+    const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        int vi0 = (vl[i] >>  0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+        vi0    |= (vh[i] <<  4) & 0x00000010; // 0 ->  4
+        vi0    |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+        vi0    |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+        vi0    |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+        sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+        int vi1 = (vl[i] >>  4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+        vi1    |= (vh[i] >> 12) & 0x00000010; // 16 ->  4
+        vi1    |= (vh[i] >>  5) & 0x00001000; // 17 -> 12
+        vi1    |= (vh[i] <<  2) & 0x00100000; // 18 -> 20
+        vi1    |= (vh[i] <<  9) & 0x10000000; // 19 -> 28
+        sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+    }
+
+#ifdef GGML_CUDA_F16
+    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+    const float d5d8 = tmp.x;
+    const float m5s8 = tmp.y;
+#else
+    const float2 dm5f = __half22float2(dm5);
+    const float2 ds8f = __half22float2(ds8);
+    const float d5d8 = dm5f.x * ds8f.x;
+    const float m5s8 = dm5f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
+    const int * v, const int * u, const T & d8_0, const T & d8_1) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+    }
+
+    return d8_0*d8_1 * ((T) sumi);
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+    const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+    }
+
+#ifdef GGML_CUDA_F16
+    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+    const float d8d8 = tmp.x;
+    const float m8s8 = tmp.y;
+#else
+    const float2 dm8f = __half22float2(dm8);
+    const float2 ds8f = __half22float2(ds8);
+    const float d8d8 = dm8f.x * ds8f.x;
+    const float m8s8 = dm8f.y * ds8f.y;
+#endif // GGML_CUDA_F16
+
+    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
+    const int * v, const int * u, const float * d8_0, const float & d8_1) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
+        int sumi = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_0/2; ++i) {
+            // SIMD dot product of quantized values
+            sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+        }
+
+        sumf += d8_0[i0/(QI8_0/2)]*sumi;
+    }
+
+    return d8_1*sumf;
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ  4
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float * __restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = scales[2*i];
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+
+        sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+        sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return dm2f.x*sumf_d - dm2f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+template <int ns8>
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
+
+    float sumf    = 0.0f;
+    float sumf_d8 = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+        int sumi_d0 = 0;
+
+        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+        int sumi_d1 = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
+        }
+        sumf_d8 += dm2f0.x * sumi_d0;
+
+#pragma unroll
+        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
+        }
+        sumf_d8 += dm2f1.x * sumi_d1;
+
+        if (i0/QI8_1 < ns8) {
+            const float2 s8f = __half22float2(s8[i0/QI8_1]);
+            sumf -= dm2f0.y*s8f.x;
+            sumf -= dm2f1.y*s8f.y;
+        } else {
+            int sumi_m0 = 0;
+#pragma unroll
+            for (int i = i0; i < i0 + QI8_1/2; ++i) {
+                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+            }
+            sumf_d8 -= dm2f0.y * sumi_m0;
+
+            int sumi_m1 = 0;
+#pragma unroll
+            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+            }
+            sumf_d8 -= dm2f1.y * sumi_m1;
+        }
+    }
+
+    return sumf + d8*sumf_d8;
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi = __vsubss4(vil, vih);
+
+        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d3 * sumf;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d3, const float & d8) {
+
+    int sumi = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+        int sumi_sc = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+        }
+
+        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+    }
+
+    return d3*d8 * sumi;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K; ++i) {
+        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+        const int v0i = vl0i | vh0i;
+        const int v1i = vl1i | vh1i;
+
+        const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+        const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);
+
+    }
+
+    const float2 dm5f = __half22float2(dm5);
+
+    return dm5f.x*sumf_d - dm5f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int j = 0; j < QI8_1; ++j) {
+            sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i]);
+
+        sumf_d += ds8f.x * (sc[i] * sumi_d);
+        sumf_m += ds8f.y *   m[i]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d, const float * __restrict__ d8) {
+
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = scales[4*i];
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+    const float & d6, const float * __restrict__ d8) {
+
+    float sumf_d = 0.0f;
+
+    const int      sc_packed = get_int_b4(sc, 0);
+    const int8_t * sc_reg    = (const int8_t *) &sc_packed;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+        for (int i = i0; i < i0 + 2; ++i) {
+            sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+            sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+            sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+            sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+        }
+
+        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
+    }
+
+    return d6 * sumf_d;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;
+
+    int v[VDR_Q4_0_Q8_1_MMVQ];
+    int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+        v[i]     = get_int_b2(bq4_0->qs, iqs + i);
+        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0);
+    }
+
+    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;
+
+    int v[VDR_Q4_1_Q8_1_MMVQ];
+    int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+        v[i]     = get_int_b4(bq4_1->qs, iqs + i);
+        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1);
+    }
+
+    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;
+
+    int vl[VDR_Q5_0_Q8_1_MMVQ];
+    int vh[VDR_Q5_0_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+        vl[i]    = get_int_b2(bq5_0->qs, iqs + i);
+        vh[i]    = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0);
+    }
+
+    return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;
+
+    int vl[VDR_Q5_1_Q8_1_MMVQ];
+    int vh[VDR_Q5_1_Q8_1_MMVQ];
+    int  u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+        vl[i]    = get_int_b4(bq5_1->qs, iqs + i);
+        vh[i]    = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i));
+        u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+        u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1);
+    }
+
+    return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;
+
+    int v[VDR_Q8_0_Q8_1_MMVQ];
+    int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+    for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+        v[i] = get_int_b2(bq8_0->qs, iqs + i);
+        u[i] = get_int_b4(bq8_1->qs, iqs + i);
+    }
+
+    return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;
+
+    const int bq8_offset = QR2_K * (iqs / QI8_1);
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const uint8_t * scales = bq2_K->scales + scale_offset;
+
+    const int v = get_int_b4(bq2_K->qs, iqs);
+    int    u[QR2_K];
+    float d8[QR2_K];
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++ i) {
+        u[i]  = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+    }
+
+    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;
+
+    const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+    const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+    const float d = bq3_K->d;
+
+    const int vl = get_int_b2(bq3_K->qs, iqs);
+
+    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+    const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+    int    u[QR3_K];
+    float d8[QR3_K];
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        u[i]  = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+    }
+
+    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;
+
+    int    v[2];
+    int    u[2*QR4_K];
+    float d8[QR4_K];
+
+    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    v[0] = q4[0];
+    v[1] = q4[4];
+
+    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = __low2float(bq8i->ds);
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;
+
+    int   vl[2];
+    int   vh[2];
+    int    u[2*QR5_K];
+    float d8[QR5_K];
+
+    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+    vl[0] = ql[0];
+    vl[1] = ql[4];
+
+    vh[0] = qh[0] >> bq8_offset;
+    vh[1] = qh[4] >> bq8_offset;
+
+    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+    uint16_t aux[2];
+    const int j = bq8_offset/2;
+    if (j < 2) {
+        aux[0] = scales[j+0] & 0x3f3f;
+        aux[1] = scales[j+2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+    }
+    const uint8_t * sc = (const uint8_t *)aux;
+    const uint8_t * m  = sc + 2;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i] = __low2float(bq8i->ds);
+
+        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+        u[2*i+0] = q8[0];
+        u[2*i+1] = q8[4];
+    }
+
+    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;
+
+    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+    const int vl = get_int_b2(bq6_K->ql, iqs);
+    const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+    const int8_t * scales = bq6_K->scales + scale_offset;
+
+    int    u[QR6_K];
+    float d8[QR6_K];
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        u[i]  = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
+    }
+
+    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
+
+#define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ  2
+
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;
+
+    const int q2 = get_int_b2(bq2->qs, iqs);
+    const uint8_t * aux8 = (const uint8_t *) &q2;
+    const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1);
+
+    int sumi = 0;
+#pragma unroll
+    for (int k0 = 0; k0 < 8; k0 += 2) {
+        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
+        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
+
+        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+        const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
+        sumi = ggml_cuda_dp4a(grid0, u0, sumi);
+
+        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+        const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
+        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
+    }
+
+    const int ls = aux32 >> 28;
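+    // apply the 4-bit block scale: equivalent to sumi * (ls + 0.5) / 4, evaluated in integer arithmetic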
+    sumi = (ls*sumi + sumi/2)/4;
+    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+    return d * sumi;
+}
+
+#define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ  2
+
+static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;
+
+    const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1));
+    const uint16_t * q2 = (const uint16_t *) &q2_packed;
+    const int ls0 = bq2->scales[iqs/2] & 0x0F;
+    const int ls1 = bq2->scales[iqs/2] >> 4;
+
+    int sumi0 = 0;
+    int sumi1 = 0;
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
+        const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l0/2] >> 9));
+
+        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+        if (l0 < 4) {
+            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
+            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
+        } else {
+            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
+            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
+        }
+    }
+    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
+    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+    return d * sumi;
+}
+
+#define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ  2
+
+static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;
+
+    const int       qs_packed = get_int_b2(bq2->qs, iqs/2);
+    const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+    const int qh = bq2->qh[iqs/2];
+
+    const int       signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2);
+    const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+    const int ls0 = bq2->scales[iqs/2] & 0x0F;
+    const int ls1 = bq2->scales[iqs/2] >> 4;
+
+    int sumi0 = 0;
+    int sumi1 = 0;
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300)));
+
+        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
+        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);
+
+        const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+        const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+        if (l0 < 4) {
+            sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
+            sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
+        } else {
+            sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
+            sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
+        }
+    }
+    const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
+
+    const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+    return d * sumi;
+}
+
+#define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ  2
+
+static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx;
+
+    const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1));
+    const uint8_t * q3 = (const uint8_t *) &q3_packed;
+    const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2);
+
+    int sumi = 0;
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
+
+        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
+
+        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
+    }
+
+    const int ls = aux32 >> 28;
+    sumi = (ls*sumi + sumi/2)/2;
+    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
+    return d * sumi;
+}
+
+#define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ  2
+
+// TODO: don't use lookup table for signs
+static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx;
+
+    const int2      qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1));
+    const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+    const int qh = bq3->qh[iqs/2];
+
+    const int       signs_packed_32 = get_int_b2(bq3->signs, iqs/2);
+    const uint8_t * signs_packed_8  = (const uint8_t *) &signs_packed_32;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int2 grid_pos = make_int2(
+            iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)],
+            iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]);
+
+        const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
+        const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);
+
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+        const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+        sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+        sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
+    }
+
+    sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);
+
+    const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
+    return d * sumi;
+}
+
+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ  1
+
+static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
+
+    const int       qs_packed = get_int_b2(bq1->qs, iqs);
+    const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+    const int qh = bq1->qh[iqs];
+
+    int sumi = 0;
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];
+
+        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+        const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+        sumi = ggml_cuda_dp4a(grid0, u0, sumi);
+        sumi = ggml_cuda_dp4a(grid1, u1, sumi);
+    }
+
+    const float  d1q   = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
+    const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+    const float2 ds    = __half22float2(bq8_1[iqs].ds);
+    return d1q * (ds.x*sumi + ds.y*delta);
+}
+
+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ  1
+
+static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;
+
+    const int       qs_packed = get_int_b4(bq1->qs, iqs);
+    const uint8_t * qs        = (const uint8_t *) &qs_packed;
+
+    int   sumi[2] = {0};
+    float sumf[2] = {0.0f};
+#pragma unroll
+    for (int l0 = 0; l0 < 8; l0 += 2) {
+        const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));
+
+        const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];
+
+        const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+        const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+        const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+        const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+        sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);
+        sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);
+
+        const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
+        int sumy = 0;
+        sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);
+        sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);
+        sumf[l0/4] += delta*sumy;
+    }
+
+    const uint16_t * sc = (const uint16_t *) bq1->scales;
+
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
+    const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);
+
+    const int tmp = sc[iqs/2] >> (6*(iqs%2));
+    const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
+    const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
+    return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
+}
+
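+// expand 8 packed 4-bit iq4_nl indices into two ints holding four int8 values each, using the kvalues_iq4nl lookup table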
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
+    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
+    const int8_t * q0_8   = (const int8_t *) &q0_32;
+    const char4    val0_8 = make_char4(
+        kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
+
+    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
+    const int8_t * q1_8   = (const int8_t *) &q1_32;
+    const char4    val1_8 = make_char4(
+        kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
+
+    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
+#define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ  4
+
+static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx;
+
+    const int * q8 = (const int *) bq8_1->qs + iqs;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+        const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
+        const int2 v = get_int_from_table_16(aux_q4);
+
+        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+    }
+
+    const float d = __half2float(bq4->d) * __low2float(bq8_1->ds);
+    return d * sumi;
+}
+
+#define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ  4
+
+static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;
+
+    int sumi = 0;
+#pragma unroll
+    for (int j = 0; j < 4; ++j) {
+        const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
+        const int2 v = get_int_from_table_16(aux_q4);
+
+        const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
+        const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
+
+        sumi = ggml_cuda_dp4a(v.x, u0, sumi);
+        sumi = ggml_cuda_dp4a(v.y, u1, sumi);
+    }
+
+    const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);
+    sumi *= ls - 32;
+
+    const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
+    return d * sumi;
+}
diff --git a/llama/ggml-cuda/vendors/cuda.h b/llama/ggml-cuda/vendors/cuda.h
new file mode 100644
index 000000000..e309dd3f1
--- /dev/null
+++ b/llama/ggml-cuda/vendors/cuda.h
@@ -0,0 +1,41 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda.h>
+#include <cublas_v2.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#if CUDART_VERSION < 11020
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#endif // CUDART_VERSION < 11020
diff --git a/llama/ggml-cuda/vendors/hip.h b/llama/ggml-cuda/vendors/hip.h
new file mode 100644
index 000000000..7b3102f39
--- /dev/null
+++ b/llama/ggml-cuda/vendors/hip.h
@@ -0,0 +1,215 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
+#ifdef __HIP_PLATFORM_AMD__
+// for rocblas_initialize()
+#include "rocblas/rocblas.h"
+#endif // __HIP_PLATFORM_AMD__
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F  HIPBLAS_R_16F
+#define CUDA_R_32F  HIPBLAS_R_32F
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
+#define cublasCreate hipblasCreate
+#define cublasDestroy hipblasDestroy
+#define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cublasOperation_t hipblasOperation_t
+#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEventSynchronize hipEventSynchronize
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaHostRegister hipHostRegister
+#define cudaHostRegisterPortable hipHostRegisterPortable
+#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
+#define cudaHostUnregister hipHostUnregister
+#define cudaLaunchHostFunc hipLaunchHostFunc
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaMemsetAsync hipMemsetAsync
+#define cudaMemGetInfo hipMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamDestroy hipStreamDestroy
+#define cudaStreamFireAndForget hipStreamFireAndForget
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamPerThread hipStreamPerThread
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
+#define GCN
+#endif
+
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
+#define CDNA
+#endif
+
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
+    defined(__gfx1150__) || defined(__gfx1151__)
+#define RDNA3
+#endif
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#if defined(__gfx1010__) || defined(__gfx1012__)
+#define RDNA1
+#endif
+
+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
+typedef hip_bfloat16 nv_bfloat16;
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+    return reinterpret_cast<const int &>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+    return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+    }
+    return c;
+}
+
+static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
+    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+    unsigned int c;
+    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {
+        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+    }
+    return c;
+}
+
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+// __shfl_xor() for half2 was added in ROCm 5.6
+static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
+    typedef union half2_b32 {
+        half2 val;
+        int   b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = var;
+    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
+    return tmp.val;
+}
+#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
diff --git a/llama/ggml-cuda/vendors/musa.h b/llama/ggml-cuda/vendors/musa.h
new file mode 100644
index 000000000..7b1a4ac41
--- /dev/null
+++ b/llama/ggml-cuda/vendors/musa.h
@@ -0,0 +1,163 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include <musa_runtime.h>
+#include <musa.h>
+#include <mublas.h>
+#include <musa_fp16.h>
+#include <musa_bf16.h>
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
+#define CUDA_R_16F  MUSA_R_16F
+#define CUDA_R_32F  MUSA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasGemmEx mublasGemmEx
+#define cublasGemmBatchedEx mublasGemmBatchedEx
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+#define cublasHandle_t mublasHandle_t
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasSgemm mublasSgemm
+#define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
+#define cublasGetStatusString mublasStatus_to_string
+#define cudaDataType_t musaDataType_t
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceProp musaDeviceProp
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaError_t musaError_t
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventRecord musaEventRecord
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEvent_t musaEvent_t
+#define cudaEventDestroy musaEventDestroy
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaHostRegister musaHostRegister
+#define cudaHostRegisterPortable musaHostRegisterPortable
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+#define cudaHostUnregister musaHostUnregister
+#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMemsetAsync musaMemsetAsync
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+#define cudaSetDevice musaSetDevice
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamFireAndForget musaStreamFireAndForget
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamWaitEvent musaStreamWaitEvent
+#define cudaStream_t musaStream_t
+#define cudaSuccess musaSuccess
+
+// Additional mappings for MUSA virtual memory pool
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+#define CUmemAccessDesc MUmemAccessDesc
+#define CUmemAllocationProp MUmemAllocationProp
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+#define cuMemAddressFree muMemAddressFree
+#define cuMemAddressReserve muMemAddressReserve
+#define cuMemCreate muMemCreate
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+#define cuMemMap muMemMap
+#define cuMemRelease muMemRelease
+#define cuMemSetAccess muMemSetAccess
+#define cuMemUnmap muMemUnmap
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+#define make_cudaExtent make_musaExtent
+#define make_cudaPitchedPtr make_musaPitchedPtr
+
+// Additional mappings for MUSA graphs
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUresult MUresult
+#define cuGetErrorString muGetErrorString
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+#define cudaGraphDestroy musaGraphDestroy
+#define cudaGraphExecDestroy musaGraphExecDestroy
+#define cudaGraphExec_t musaGraphExec_t
+#define cudaGraphExecUpdate musaGraphExecUpdate
+#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphGetNodes musaGraphGetNodes
+#define cudaGraphInstantiate musaGraphInstantiate
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+#define cudaGraphLaunch musaGraphLaunch
+#define cudaGraphNodeGetType musaGraphNodeGetType
+#define cudaGraphNode_t musaGraphNode_t
+#define cudaGraphNodeType musaGraphNodeType
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+#define cudaGraph_t musaGraph_t
+#define cudaKernelNodeParams musaKernelNodeParams
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamEndCapture musaStreamEndCapture
+
+typedef mt_bfloat16 nv_bfloat16;
diff --git a/llama/ggml-cuda/wkv6.cu b/llama/ggml-cuda/wkv6.cu
new file mode 100644
index 000000000..fe4e5b9d4
--- /dev/null
+++ b/llama/ggml-cuda/wkv6.cu
@@ -0,0 +1,115 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+#include "wkv6.cuh"
+
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = CUDA_WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
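Editorial sketch (not part of the vendored file): the launch geometry reconstructed above mirrors the kernel's indexing, with one block per (batch, head) pair and one thread per channel inside a head; B, H and C below are hypothetical shapes chosen only to make the arithmetic concrete, assuming the CUDA types pulled in via common.cuh.

// e.g. B = 2 sequences, H = 32 heads, C = 2048 channels (C / H == 64 == CUDA_WKV_BLOCK_SIZE):
//   grid  = B * H = 64 blocks, block = C / H = 64 threads
static void wkv6_launch_dims(int64_t B, int64_t H, int64_t C, dim3 & grid, dim3 & block) {
    grid  = dim3((unsigned int)(B * H)); // kernel derives batch_i = bid / H, head_i = bid % H
    block = dim3((unsigned int)(C / H)); // tid walks one 64-wide head
}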
diff --git a/llama/ggml-cuda/wkv6.cuh b/llama/ggml-cuda/wkv6.cuh
new file mode 100644
index 000000000..270272878
--- /dev/null
+++ b/llama/ggml-cuda/wkv6.cuh
@@ -0,0 +1,31 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama/ggml-impl.h b/llama/ggml-impl.h
new file mode 100644
index 000000000..46760fb32
--- /dev/null
+++ b/llama/ggml-impl.h
@@ -0,0 +1,598 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+// GGML internal header
+
+#include "ggml.h"
+#include <assert.h>
+#include <math.h>
+#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
+#include <stdbool.h>
+#include <stdint.h>
+#include <string.h>
+
+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>
+#endif // __ARM_FEATURE_SVE
+
+#if defined(__ARM_NEON) && !defined(__CUDACC__)
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+#endif
+
+#if defined(__F16C__)
+#include <immintrin.h>
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifndef MIN
+#    define MIN(a, b) ((a) < (b) ? (a) : (b))
+#endif
+
+#ifndef MAX
+#    define MAX(a, b) ((a) > (b) ? (a) : (b))
+#endif
+
+// required for mmap as gguf only guarantees 32-byte alignment
+#define TENSOR_ALIGNMENT 32
+
+// static_assert should be a #define, but if it's not,
+// fall back to the _Static_assert C11 keyword.
+// if C99 - static_assert is noop
+// ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef __cplusplus
+    #ifndef static_assert
+        #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+            #define static_assert(cond, msg) _Static_assert(cond, msg)
+        #else
+            #define static_assert(cond, msg) struct global_scope_noop_trick
+        #endif
+    #endif
+#endif
+
+static inline int ggml_up32(int n) {
+    return (n + 31) & ~31;
+}
+
+//static inline int ggml_up64(int n) {
+//    return (n + 63) & ~63;
+//}
+
+static inline int ggml_up(int n, int m) {
+    // assert m is a power of 2
+    GGML_ASSERT((m & (m - 1)) == 0);
+    return (n + m - 1) & ~(m - 1);
+}
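A short worked sketch of the helpers above (not part of the vendored header): ggml_up rounds n up to the next multiple of a power-of-two m with the usual add-then-mask trick, and ggml_up32 is the m == 32 special case.

static inline void example_ggml_up(void) {
    assert(ggml_up32(100)   == 128);  // (100 + 31) & ~31
    assert(ggml_up(100, 64) == 128);  // (100 + 63) & ~63
    assert(ggml_up(128, 64) == 128);  // already a multiple of 64, unchanged
}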
+
+//
+// logging
+//
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+GGML_API void ggml_log_internal        (enum ggml_log_level level, const char * format, ...);
+GGML_API void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data);
+
+#define GGML_LOG(...)       ggml_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
+#define GGML_LOG_INFO(...)  ggml_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
+#define GGML_LOG_WARN(...)  ggml_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
+#define GGML_LOG_ERROR(...) ggml_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+#define GGML_LOG_DEBUG(...) ggml_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
+#define GGML_LOG_CONT(...)  ggml_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
+
+#define GGML_DEBUG 0
+
+#if (GGML_DEBUG >= 1)
+#define GGML_PRINT_DEBUG(...) GGML_LOG_DEBUG(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG(...)
+#endif
+
+#if (GGML_DEBUG >= 5)
+#define GGML_PRINT_DEBUG_5(...) GGML_LOG_DEBUG(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_5(...)
+#endif
+
+#if (GGML_DEBUG >= 10)
+#define GGML_PRINT_DEBUG_10(...) GGML_LOG_DEBUG(__VA_ARGS__)
+#else
+#define GGML_PRINT_DEBUG_10(...)
+#endif
+
+// tensor params
+
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
+    assert(params_size <= GGML_MAX_OP_PARAMS);
+    memcpy(tensor->op_params, params, params_size);
+}
+
+static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    return ((const int32_t *)(tensor->op_params))[i];
+}
+
+static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+    return ((const float *)(tensor->op_params))[i];
+}
+
+static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    ((int32_t *)(tensor->op_params))[i] = value;
+}
+
+static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(float));
+    ((float *)(tensor->op_params))[i] = value;
+}
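Usage sketch for the op-params helpers above: small scalar operands are packed slot by slot into the fixed-size tensor->op_params byte array and read back the same way. Illustrative only; `t` is a hypothetical, already-created tensor.

static inline void example_op_params(struct ggml_tensor * t) {
    ggml_set_op_params_i32(t, 0, 42);    // write slot 0 as int32
    ggml_set_op_params_f32(t, 1, 0.5f);  // write slot 1 as float
    assert(ggml_get_op_params_i32(t, 0) == 42);
    assert(ggml_get_op_params_f32(t, 1) == 0.5f);
}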
+
+struct ggml_map_custom1_op_params {
+    ggml_custom1_op_t  fun;
+    int                n_tasks;
+    void             * userdata;
+};
+
+struct ggml_map_custom2_op_params {
+    ggml_custom2_op_t   fun;
+    int                 n_tasks;
+    void              * userdata;
+};
+
+struct ggml_map_custom3_op_params {
+    ggml_custom3_op_t fun;
+    int n_tasks;
+    void * userdata;
+};
+
+// bitset
+
+typedef uint32_t ggml_bitset_t;
+
+static_assert(sizeof(ggml_bitset_t) == 4, "bitset_t constants must be updated");
+#define BITSET_SHR 5 // log2(sizeof(ggml_bitset_t)*8)
+#define BITSET_MASK (sizeof(ggml_bitset_t)*8 - 1)
+
+static size_t ggml_bitset_size(size_t n) {
+    return (n + BITSET_MASK) >> BITSET_SHR;
+}
+
+static inline bool ggml_bitset_get(const ggml_bitset_t * bitset, size_t i) {
+    return !!(bitset[i >> BITSET_SHR] & (1u << (i & BITSET_MASK)));
+}
+
+static inline void ggml_bitset_set(ggml_bitset_t * bitset, size_t i) {
+    bitset[i >> BITSET_SHR] |= (1u << (i & BITSET_MASK));
+}
+
+static inline void ggml_bitset_clear(ggml_bitset_t * bitset, size_t i) {
+    bitset[i >> BITSET_SHR] &= ~(1u << (i & BITSET_MASK));
+}
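Worked example of the shift/mask constants (illustrative sketch): they implement the standard word/bit split for a bitset built from 32-bit words.

// For bit index i = 37: word = i >> BITSET_SHR = 1, bit = i & BITSET_MASK = 5,
// so ggml_bitset_set(b, 37) ORs (1u << 5) into b[1], and ggml_bitset_size(37) == 2 words.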
+
+// hash set
+
+#define GGML_HASHSET_FULL ((size_t)-1)
+#define GGML_HASHSET_ALREADY_EXISTS ((size_t)-2)
+
+struct ggml_hash_set {
+    size_t size;
+    ggml_bitset_t * used;       // whether or not the keys are in use i.e. set
+    struct ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if ggml_bitset_get(used, i)
+};
+
+struct ggml_hash_set ggml_hash_set_new(size_t size);
+void                 ggml_hash_set_free(struct ggml_hash_set * hash_set);
+
+// returns the minimum size for a hash set that can hold min_sz elements
+size_t ggml_hash_size(size_t min_sz);
+
+// remove all elements from the hash set
+void ggml_hash_set_reset(struct ggml_hash_set * hash_set);
+
+// returns true if key is in the hash set
+static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);
+
+// returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key);
+
+// returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
+static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key);
+
+// return index, asserts if table is full
+static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key);
+
+// hash function for ggml_tensor
+static inline size_t ggml_hash(const struct ggml_tensor * p) {
+    // the last 4 bits are always zero due to alignment
+    return (size_t)(uintptr_t)p >> 4;
+}
+
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) {
+    size_t h = ggml_hash(key) % hash_set->size;
+
+    // linear probing
+    size_t i = h;
+    while (ggml_bitset_get(hash_set->used, i) && hash_set->keys[i] != key) {
+        i = (i + 1) % hash_set->size;
+        if (i == h) {
+            // visited all hash table entries -> not found
+            return GGML_HASHSET_FULL;
+        }
+    }
+    return i;
+}
+
+static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) {
+    size_t i = ggml_hash_find(hash_set, key);
+    return i != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, i);
+}
+
+static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) {
+    size_t h = ggml_hash(key) % hash_set->size;
+
+    // linear probing
+    size_t i = h;
+    do {
+        if (!ggml_bitset_get(hash_set->used, i)) {
+            ggml_bitset_set(hash_set->used, i);
+            hash_set->keys[i] = key;
+            return i;
+        }
+        if (hash_set->keys[i] == key) {
+            return GGML_HASHSET_ALREADY_EXISTS;
+        }
+        i = (i + 1) % hash_set->size;
+    } while (i != h);
+
+    // visited all hash table entries -> not found
+    GGML_ABORT("fatal error");
+}
+
+static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) {
+    size_t h = ggml_hash(key) % hash_set->size;
+
+    // linear probing
+    size_t i = h;
+    do {
+        if (!ggml_bitset_get(hash_set->used, i)) {
+            ggml_bitset_set(hash_set->used, i);
+            hash_set->keys[i] = key;
+            return i;
+        }
+        if (hash_set->keys[i] == key) {
+            return i;
+        }
+        i = (i + 1) % hash_set->size;
+    } while (i != h);
+
+    // visited all hash table entries -> not found
+    GGML_ABORT("fatal error");
+}
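Usage sketch of the hash-set API above (illustrative, not part of the vendored header): the set is open-addressed with linear probing and keyed on tensor addresses; `t` is a hypothetical tensor pointer and the capacity of 16 is arbitrary.

static inline void example_hash_set(struct ggml_tensor * t) {
    struct ggml_hash_set set = ggml_hash_set_new(16);
    size_t slot = ggml_hash_insert(&set, t);                           // index where t was stored
    assert(ggml_hash_insert(&set, t) == GGML_HASHSET_ALREADY_EXISTS);  // duplicate insert is detected
    assert(ggml_hash_contains(&set, t));
    ggml_hash_set_free(&set);
    (void) slot;
}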
+
+// computation graph
+
+enum ggml_cgraph_eval_order {
+    GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0,
+    GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT,
+    GGML_CGRAPH_EVAL_ORDER_COUNT
+};
+
+struct ggml_cgraph {
+    int size;    // maximum number of nodes/leafs/grads/grad_accs
+    int n_nodes; // number of nodes currently in use
+    int n_leafs; // number of leafs currently in use
+
+    struct ggml_tensor ** nodes;     // tensors with data that can change if the graph is evaluated
+    struct ggml_tensor ** grads;     // the outputs of these tensors are the gradients of the nodes
+    struct ggml_tensor ** grad_accs; // accumulators for node gradients
+    struct ggml_tensor ** leafs;     // tensors with constant data
+
+    struct ggml_hash_set visited_hash_set;
+
+    enum ggml_cgraph_eval_order order;
+};
+
+// returns a slice of cgraph with nodes [i0, i1)
+// the slice does not have leafs or gradients
+// if you need the gradients, get them from the original graph
+struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
+
+// Memory allocation
+
+GGML_API void * ggml_aligned_malloc(size_t size);
+GGML_API void ggml_aligned_free(void * ptr, size_t size);
+
+// FP16 to FP32 conversion
+
+#if defined(__ARM_NEON)
+    #if defined(_MSC_VER) || (defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11)
+        typedef uint16_t ggml_fp16_internal_t;
+    #else
+        typedef __fp16 ggml_fp16_internal_t;
+    #endif
+#endif
+
+#if defined(__ARM_NEON) && !defined(_MSC_VER) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11)
+    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+    #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+
+    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+        ggml_fp16_internal_t tmp;
+        memcpy(&tmp, &h, sizeof(ggml_fp16_t));
+        return (float)tmp;
+    }
+
+    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+        ggml_fp16_t res;
+        ggml_fp16_internal_t tmp = f;
+        memcpy(&res, &tmp, sizeof(ggml_fp16_t));
+        return res;
+    }
+
+#elif defined(__F16C__)
+
+    #ifdef _MSC_VER
+        #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+        #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+    #else
+        #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+        #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+    #endif
+
+#elif defined(__POWER9_VECTOR__)
+
+    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+    /* the inline asm below is about 12% faster than the lookup method */
+    #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+    #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+        register float f;
+        register double d;
+        __asm__(
+            "mtfprd %0,%2\n"
+            "xscvhpdp %0,%0\n"
+            "frsp %1,%0\n" :
+            /* temp */ "=d"(d),
+            /* out */  "=f"(f):
+            /* in */   "r"(h));
+        return f;
+    }
+
+    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+        register double d;
+        register ggml_fp16_t r;
+        __asm__( /* xscvdphp can work on double or single precision */
+            "xscvdphp %0,%2\n"
+            "mffprd %1,%0\n" :
+            /* temp */ "=d"(d),
+            /* out */  "=r"(r):
+            /* in */   "f"(f));
+        return r;
+    }
+
+#else
+
+    // FP16 <-> FP32
+    // ref: https://github.com/Maratyszcza/FP16
+
+    static inline float fp32_from_bits(uint32_t w) {
+        union {
+            uint32_t as_bits;
+            float as_value;
+        } fp32;
+        fp32.as_bits = w;
+        return fp32.as_value;
+    }
+
+    static inline uint32_t fp32_to_bits(float f) {
+        union {
+            float as_value;
+            uint32_t as_bits;
+        } fp32;
+        fp32.as_value = f;
+        return fp32.as_bits;
+    }
+
+    static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+        const uint32_t w = (uint32_t) h << 16;
+        const uint32_t sign = w & UINT32_C(0x80000000);
+        const uint32_t two_w = w + w;
+
+        const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+    #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+        const float exp_scale = 0x1.0p-112f;
+    #else
+        const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+    #endif
+        const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+        const uint32_t magic_mask = UINT32_C(126) << 23;
+        const float magic_bias = 0.5f;
+        const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+        const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+        const uint32_t result = sign |
+            (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+        return fp32_from_bits(result);
+    }
+
+    static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+        const float scale_to_inf = 0x1.0p+112f;
+        const float scale_to_zero = 0x1.0p-110f;
+    #else
+        const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+        const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+    #endif
+        float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+        const uint32_t w = fp32_to_bits(f);
+        const uint32_t shl1_w = w + w;
+        const uint32_t sign = w & UINT32_C(0x80000000);
+        uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+        if (bias < UINT32_C(0x71000000)) {
+            bias = UINT32_C(0x71000000);
+        }
+
+        base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+        const uint32_t bits = fp32_to_bits(base);
+        const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+        const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+        const uint32_t nonsign = exp_bits + mantissa_bits;
+        return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+    }
+
+    #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+    #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
+
+// precomputed f32 table for f16 (256 KB)
+// defined in ggml.c, initialized in ggml_init()
+GGML_API float ggml_table_f32_f16[1 << 16];
+
+// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
+// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+// This is also true for POWER9.
+#if !defined(GGML_FP16_TO_FP32)
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#endif
+
+#if !defined(GGML_FP32_TO_FP16)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+#endif
+
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This is binary identical with Google Brain float conversion.
+ * Floats shall round to nearest even, and NANs shall be quiet.
+ * Subnormals aren't flushed to zero, except perhaps when used.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
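A worked round-trip through the bf16 helpers above, as an illustrative sketch (not part of the vendored header).

// 1.0f is 0x3F800000 in IEEE binary32; rounding to the top 16 bits gives the
// bf16 pattern 0x3F80, and shifting it back left by 16 recovers exactly 1.0f.
static inline void example_bf16_roundtrip(void) {
    ggml_bf16_t h = GGML_FP32_TO_BF16(1.0f);
    assert(h.bits == 0x3F80);
    assert(GGML_BF16_TO_FP32(h) == 1.0f);
}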
+
+// expose GGUF internals for test code
+
+GGML_API size_t gguf_type_size(enum gguf_type type);
+
+GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+
+struct gguf_buf {
+    void * data;
+    size_t size;
+    size_t offset;
+};
+GGML_API struct gguf_buf gguf_buf_init(size_t size);
+GGML_API void gguf_buf_free(struct gguf_buf buf);
+
+GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml-metal-embed.metal b/llama/ggml-metal-embed.metal
new file mode 100644
index 000000000..7f4666c93
--- /dev/null
+++ b/llama/ggml-metal-embed.metal
@@ -0,0 +1,8994 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_DECL_METAL
+#define GGML_COMMON_IMPL_METAL
+#if defined(GGML_METAL_EMBED_LIBRARY)
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef GGML_COMMON_DECL
+
+#if defined(GGML_COMMON_DECL_C)
+#include <stdint.h>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CPP)
+#include <cstdint>
+
+typedef uint16_t ggml_half;
+typedef uint32_t ggml_half2;
+
+// std-c++ allow anonymous unions but some compiler warn on it
+#define GGML_COMMON_AGGR_U data
+// std-c++ do not allow it.
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_METAL)
+#include <metal_stdlib>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_CUDA)
+#if defined(GGML_COMMON_DECL_MUSA)
+#include <musa_fp16.h>
+#else
+#include <cuda_fp16.h>
+#endif
+#include <cstdint>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_HIP)
+#include <hip/hip_fp16.h>
+#include <cstdint>
+
+typedef half  ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_SYCL)
+#include <sycl/half_type.hpp>
+#include <cstdint>
+
+typedef sycl::half  ggml_half;
+typedef sycl::half2 ggml_half2;
+
+#define GGML_COMMON_AGGR_U
+#define GGML_COMMON_AGGR_S data
+
+#define GGML_COMMON_DECL
+#endif
+
+#if defined(GGML_COMMON_DECL)
+
+#ifndef __cplusplus
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+#endif // __cplusplus
+
+// QK = number of values after dequantization
+// QK_K = super-block size
+
+#define QK_K 256
+#define K_SCALE_SIZE 12
+
+#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
+// QR = QK / number of values before dequantization
+// QI = number of 32 bit integers before dequantization
+
+#define QI4_0 (QK4_0 / (4 * QR4_0))
+#define QR4_0 2
+
+#define QI4_1 (QK4_1 / (4 * QR4_1))
+#define QR4_1 2
+
+#define QI5_0 (QK5_0 / (4 * QR5_0))
+#define QR5_0 2
+
+#define QI5_1 (QK5_1 / (4 * QR5_1))
+#define QR5_1 2
+
+#define QI8_0 (QK8_0 / (4 * QR8_0))
+#define QR8_0 1
+
+#define QI8_1 (QK8_1 / (4 * QR8_1))
+#define QR8_1 1
+
+#define QI2_K (QK_K / (4*QR2_K))
+#define QR2_K 4
+
+#define QI3_K (QK_K / (4*QR3_K))
+#define QR3_K 4
+
+#define QI4_K (QK_K / (4*QR4_K))
+#define QR4_K 2
+
+#define QI5_K (QK_K / (4*QR5_K))
+#define QR5_K 2
+
+#define QI6_K (QK_K / (4*QR6_K))
+#define QR6_K 2
+
+#define QI2_XXS (QK_K / (4*QR2_XXS))
+#define QR2_XXS 4
+
+#define QI2_XS (QK_K / (4*QR2_XS))
+#define QR2_XS 4
+
+#define QI2_S (QK_K / (4*QR2_S))
+#define QR2_S 4
+
+#define QI3_XXS (QK_K / (4*QR3_XXS))
+#define QR3_XXS 4
+
+#define QI3_XS (QK_K / (4*QR3_XS))
+#define QR3_XS 4
+
+#define QI1_S (QK_K / (4*QR1_S))
+#define QR1_S 8
+
+#define QI1_M (QK_K / (4*QR1_M))
+#define QR1_M 8
+
+#define QI4_NL (QK4_NL / (4*QR4_NL))
+#define QR4_NL 2
+
+#define QI4_XS (QK_K / (4*QR4_XS))
+#define QR4_XS 2
+
+#define QI3_S (QK_K / (4*QR3_S))
+#define QR3_S 4
+
+#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
+
+#define QK4_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t qs[QK4_0 / 2]; // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_half) + QK4_0 / 2, "wrong q4_0 block size/padding");
+
+#define QK4_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half m; // min
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t qs[QK4_1 / 2]; // nibbles / quants
+} block_q4_1;
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+
+#define QK5_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+static_assert(sizeof(block_q5_0) == sizeof(ggml_half) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
+
+#define QK5_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half m; // min
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
+
+#define QK8_0 32
+typedef struct {
+    ggml_half d;       // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(ggml_half) + QK8_0, "wrong q8_0 block size/padding");
+
+#define QK8_1 32
+typedef struct {
+    union {
+        struct {
+            ggml_half d; // delta
+            ggml_half s; // d * sum(qs[i])
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 ds;
+    } GGML_COMMON_AGGR_U;
+    int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
+
+//
+// Ternary quantization
+//
+
+// 1.6875 bpw
+typedef struct {
+    uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256)
+    uint8_t qh[QK_K/64]; // 4 elements per byte
+    ggml_half d;
+} block_tq1_0;
+static_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding");
+
+// 2.0625 bpw
+typedef struct {
+    uint8_t qs[QK_K/4]; // 2 bits per element
+    ggml_half d;
+} block_tq2_0;
+static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");
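The bits-per-weight figures quoted for the ternary blocks follow directly from the struct layouts; a worked check, illustrative only, assuming QK_K = 256 and a 2-byte ggml_half.

// block_tq1_0: 48 B (qs) + 4 B (qh) + 2 B (d) = 54 B per 256 weights -> 54 * 8 / 256 = 1.6875 bpw
// block_tq2_0: 64 B (qs)            + 2 B (d) = 66 B per 256 weights -> 66 * 8 / 256 = 2.0625 bpw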
+
+//
+// Super-block quantization structures
+//
+
+// 2-bit quantization
+// weight is represented as x = a * q + b
+// 16 blocks of 16 elements each
+// Effectively 2.625 bits per weight
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+} block_q2_K;
+static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
+
+// 3-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 3.4375 bits per weight
+typedef struct {
+    uint8_t hmask[QK_K/8]; // quants - high bit
+    uint8_t qs[QK_K/4];    // quants - low 2 bits
+    uint8_t scales[12];    // scales, quantized with 6 bits
+    ggml_half d;           // super-block scale
+} block_q3_K;
+static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
+
+// 4-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 4.5 bits per weight
+typedef struct {
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];           // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
+
+// 5-bit quantization
+// 8 blocks of 32 elements each
+// weight is represented as x = a * q + b
+// Effectively 5.5 bits per weight
+typedef struct {
+    union {
+        struct {
+            ggml_half d;    // super-block scale for quantized scales
+            ggml_half dmin; // super-block scale for quantized mins
+        } GGML_COMMON_AGGR_S;
+        ggml_half2 dm;
+    } GGML_COMMON_AGGR_U;
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
+    uint8_t qh[QK_K/8];           // quants, high bit
+    uint8_t qs[QK_K/2];           // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    ggml_half d;             // super-block scale
+} block_q6_K;
+static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
+
+// This is only used for intermediate quantization and dot products
+typedef struct {
+    float   d;              // delta
+    int8_t  qs[QK_K];       // quants
+    int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+} block_q8_K;
+static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
+
+// (Almost) "true" 2-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 2.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_half d;
+    uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
+
+// 2.3125 bpw quants
+typedef struct {
+    ggml_half d;
+    uint16_t qs[QK_K/8];
+    uint8_t  scales[QK_K/32];
+} block_iq2_xs;
+static_assert(sizeof(block_iq2_xs) == sizeof(ggml_half) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
+
+// 2.5625 bpw quants
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t scales[QK_K/32];
+} block_iq2_s;
+static_assert(sizeof(block_iq2_s) == sizeof(ggml_half) + QK_K/4 + QK_K/16, "wrong iq2_s block size/padding");
+
+// (Almost) "true" 3-bit quantization.
+// Due to the need to use blocks as per ggml design, it ends up using
+// 3.0625 bpw because of the 16-bit scale for each block of 256.
+typedef struct {
+    ggml_half d;
+    uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
+
+// 3.4375 bpw
+#define IQ3S_N_SCALE QK_K/64
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK_K/4];
+    uint8_t qh[QK_K/32];
+    uint8_t signs[QK_K/8];
+    uint8_t scales[IQ3S_N_SCALE];
+} block_iq3_s;
+static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
+
+// 1.5625 bpw
+typedef struct {
+    ggml_half d;
+    uint8_t  qs[QK_K/8];
+    uint16_t qh[QK_K/32];
+} block_iq1_s;
+static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
+
+// 1.75 bpw
+typedef struct {
+    uint8_t  qs[QK_K/8];      // grid index, low 8 bits
+    uint8_t  qh[QK_K/16];     // grid index, high 3 bits + grid shift bit (for two groups of 8)
+    uint8_t  scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64)
+} block_iq1_m;
+static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding");
+
+// Used by IQ1_M quants
+typedef union {
+    ggml_half f16;
+    uint16_t  u16;
+} iq1m_scale_t;
+
+// Non-linear quants
+#define QK4_NL 32
+typedef struct {
+    ggml_half d;
+    uint8_t qs[QK4_NL/2];
+} block_iq4_nl;
+static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding");
+
+typedef struct {
+    ggml_half d;
+    uint16_t scales_h;
+    uint8_t  scales_l[QK_K/64];
+    uint8_t  qs[QK_K/2];
+} block_iq4_xs;
+static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
+
+#endif // GGML_COMMON_DECL
+#endif // GGML_COMMON_DECL
+
+////////////////////////////////////////////////////////////////////////////////
+
+#ifndef GGML_COMMON_IMPL
+
+#if defined(GGML_COMMON_IMPL_C)
+#include <stdint.h>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CPP)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_METAL)
+#include <metal_stdlib>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const constant type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#elif defined(GGML_COMMON_IMPL_SYCL)
+
+#include <cstdint>
+
+#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define GGML_TABLE_END() };
+
+#define GGML_COMMON_IMPL
+#endif
+
+#if defined(GGML_COMMON_IMPL)
+
+GGML_TABLE_BEGIN(uint8_t, kmask_iq2xs, 8)
+    1, 2, 4, 8, 16, 32, 64, 128
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
+      0, 129, 130,   3, 132,   5,   6, 135, 136,   9,  10, 139,  12, 141, 142,  15,
+    144,  17,  18, 147,  20, 149, 150,  23,  24, 153, 154,  27, 156,  29,  30, 159,
+    160,  33,  34, 163,  36, 165, 166,  39,  40, 169, 170,  43, 172,  45,  46, 175,
+     48, 177, 178,  51, 180,  53,  54, 183, 184,  57,  58, 187,  60, 189, 190,  63,
+    192,  65,  66, 195,  68, 197, 198,  71,  72, 201, 202,  75, 204,  77,  78, 207,
+     80, 209, 210,  83, 212,  85,  86, 215, 216,  89,  90, 219,  92, 221, 222,  95,
+     96, 225, 226,  99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+    240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+GGML_TABLE_END()
+
+//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
+GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
+    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+GGML_TABLE_END()
+//#endif
+
+
+GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+    0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+    0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+    0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+    0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+    0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+    0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+    0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+    0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+    0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+    0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+    0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+    0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+    0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+    0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+    0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+    0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+    0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+    0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+    0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+    0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+    0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+    0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+    0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+    0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+    0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+    0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+    0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+    0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+    0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+    0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+    0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+    0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+    0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+    0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+    0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+    0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+    0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+    0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+    0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+    0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+    0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+    0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+    0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+    0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+    0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+    0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+    0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+    0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+    0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+    0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+    0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+    0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+    0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+    0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+    0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+    0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+    0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+    0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2xs_grid, 512)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+    0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+    0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+    0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+    0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+    0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+    0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+    0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+    0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+    0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+    0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+    0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+    0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+    0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+    0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+    0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+    0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+    0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+    0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+    0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+    0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+    0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+    0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+    0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+    0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+    0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+    0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+    0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+    0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+    0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+    0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+    0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+    0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+    0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+    0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+    0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+    0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+    0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+    0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+    0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+    0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+    0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+    0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+    0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+    0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+    0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+    0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+    0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+    0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+    0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+    0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+    0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+    0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+    0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+    0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+    0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+    0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+    0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+    0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+    0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+    0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+    0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+    0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+    0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+    0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+    0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+    0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+    0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+    0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+    0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+    0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+    0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+    0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+    0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+    0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+    0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+    0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+    0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+    0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+    0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+    0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+    0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+    0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+    0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+    0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+    0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+    0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+    0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+    0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+    0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+    0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+    0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+    0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+    0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+    0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+    0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+    0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+    0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+    0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+    0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+    0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+    0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+    0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+    0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+    0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+    0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+    0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+    0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+    0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+    0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+    0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+    0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+    0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+    0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+    0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+    0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+    0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint64_t, iq2s_grid, 1024)
+    0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+    0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+    0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+    0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+    0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+    0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
+    0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
+    0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
+    0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
+    0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
+    0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
+    0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
+    0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
+    0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
+    0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
+    0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
+    0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
+    0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
+    0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
+    0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
+    0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
+    0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
+    0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
+    0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
+    0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
+    0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
+    0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
+    0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
+    0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
+    0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
+    0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
+    0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
+    0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
+    0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
+    0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
+    0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
+    0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
+    0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
+    0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
+    0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
+    0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
+    0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
+    0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
+    0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
+    0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
+    0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
+    0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
+    0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
+    0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
+    0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
+    0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
+    0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
+    0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
+    0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
+    0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
+    0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
+    0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
+    0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
+    0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
+    0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
+    0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
+    0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
+    0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
+    0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
+    0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
+    0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
+    0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
+    0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
+    0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
+    0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
+    0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
+    0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
+    0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
+    0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
+    0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
+    0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
+    0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
+    0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
+    0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
+    0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
+    0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
+    0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
+    0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
+    0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
+    0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
+    0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
+    0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
+    0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
+    0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
+    0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
+    0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
+    0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
+    0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
+    0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
+    0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
+    0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
+    0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
+    0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
+    0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
+    0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
+    0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
+    0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
+    0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
+    0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
+    0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
+    0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
+    0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
+    0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
+    0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
+    0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
+    0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
+    0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
+    0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
+    0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
+    0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
+    0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
+    0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
+    0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
+    0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
+    0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
+    0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
+    0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
+    0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
+    0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
+    0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
+    0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
+    0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
+    0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
+    0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
+    0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
+    0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
+    0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
+    0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
+    0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
+    0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
+    0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
+    0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
+    0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
+    0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
+    0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
+    0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
+    0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
+    0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
+    0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
+    0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
+    0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
+    0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
+    0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
+    0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
+    0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
+    0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
+    0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
+    0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
+    0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
+    0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
+    0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
+    0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
+    0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
+    0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
+    0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
+    0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
+    0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
+    0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
+    0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
+    0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
+    0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
+    0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
+    0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
+    0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
+    0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
+    0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
+    0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
+    0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
+    0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
+    0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
+    0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
+    0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
+    0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
+    0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
+    0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
+    0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
+    0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
+    0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
+    0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
+    0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
+    0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
+    0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
+    0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
+    0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
+    0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
+    0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
+    0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
+    0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
+    0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
+    0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
+    0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
+    0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
+    0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
+    0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
+    0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
+    0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
+    0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
+    0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
+    0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
+    0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
+    0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
+    0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
+    0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
+    0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
+    0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
+    0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
+    0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
+    0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
+    0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
+    0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
+    0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
+    0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
+    0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
+    0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
+    0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
+    0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
+    0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
+    0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
+    0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
+    0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
+    0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
+    0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
+    0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
+    0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
+    0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
+    0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
+    0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
+    0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
+    0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
+    0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
+    0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
+    0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
+    0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
+    0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
+    0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
+    0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
+    0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
+    0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
+    0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
+    0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
+    0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
+    0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
+    0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
+    0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
+    0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
+    0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
+    0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
+    0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
+    0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
+    0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
+    0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint32_t, iq3xxs_grid, 256)
+    0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+    0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+    0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+    0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+    0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+    0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+    0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+    0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+    0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+    0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+    0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+    0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+    0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+    0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+    0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+    0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+    0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+    0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+    0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+    0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+    0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+    0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+    0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+    0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+    0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+    0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+    0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+    0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+    0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+    0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+    0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+    0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+GGML_TABLE_END()
+
+GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
+    0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
+    0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
+    0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
+    0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
+    0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
+    0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
+    0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
+    0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
+    0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
+    0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
+    0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
+    0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
+    0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
+    0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
+    0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
+    0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
+    0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
+    0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
+    0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
+    0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
+    0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
+    0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
+    0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
+    0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
+    0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
+    0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
+    0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
+    0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
+    0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
+    0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
+    0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
+    0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
+    0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
+    0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
+    0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
+    0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
+    0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
+    0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
+    0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
+    0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
+    0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
+    0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
+    0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
+    0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
+    0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
+    0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
+    0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
+    0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
+    0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
+    0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
+    0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
+    0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
+    0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
+    0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
+    0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
+    0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
+    0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
+    0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
+    0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
+    0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
+    0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
+    0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
+    0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
+    0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
+GGML_TABLE_END()
+
+#define NGRID_IQ1S 2048
+#define IQ1S_DELTA 0.125f
+#define IQ1M_DELTA 0.125f
+#if defined(GGML_COMMON_IMPL_C)
+GGML_TABLE_BEGIN(uint64_t, iq1s_grid, NGRID_IQ1S)
+    0xffffffffffffffff, 0xffffffffffffff01, 0xffffffffffff0000, 0xffffffffffff01ff,
+    0xffffffffffff0101, 0xffffffffff00ff00, 0xffffffffff000000, 0xffffffffff01ffff,
+    0xffffffffff01ff01, 0xffffffffff0101ff, 0xffffffffff010101, 0xffffffff00ff0000,
+    0xffffffff0000ff00, 0xffffffff000000ff, 0xffffffff00000001, 0xffffffff00010000,
+    0xffffffff01ffffff, 0xffffffff01ffff01, 0xffffffff01ff01ff, 0xffffffff01ff0101,
+    0xffffffff01000000, 0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff010101ff,
+    0xffffffff01010101, 0xffffff00ffff00ff, 0xffffff00ffff0000, 0xffffff00ff00ff00,
+    0xffffff00ff0000ff, 0xffffff00ff000001, 0xffffff00ff000100, 0xffffff00ff000101,
+    0xffffff00ff010000, 0xffffff0000ffff00, 0xffffff0000ff0001, 0xffffff0000ff0100,
+    0xffffff000000ff01, 0xffffff0000000000, 0xffffff0000000101, 0xffffff000001ff00,
+    0xffffff00000100ff, 0xffffff0000010001, 0xffffff00000101ff, 0xffffff0001ff0000,
+    0xffffff000100ff00, 0xffffff00010000ff, 0xffffff0001000001, 0xffffff0001010000,
+    0xffffff01ffffffff, 0xffffff01ffffff01, 0xffffff01ffff01ff, 0xffffff01ffff0101,
+    0xffffff01ff000000, 0xffffff01ff01ffff, 0xffffff01ff01ff01, 0xffffff01ff0101ff,
+    0xffffff01ff010101, 0xffffff0100ff0000, 0xffffff010000ff00, 0xffffff0100000100,
+    0xffffff01000100ff, 0xffffff0100010100, 0xffffff0101ffffff, 0xffffff0101ffff01,
+    0xffffff0101ff01ff, 0xffffff0101ff0101, 0xffffff010100ff00, 0xffffff0101000000,
+    0xffffff0101000100, 0xffffff010101ffff, 0xffffff010101ff01, 0xffffff01010101ff,
+    0xffffff0101010101, 0xffff00ffff00ff00, 0xffff00ffff0000ff, 0xffff00ffff000001,
+    0xffff00ffff010000, 0xffff00ff00ffff00, 0xffff00ff00ff0100, 0xffff00ff00000000,
+    0xffff00ff00000101, 0xffff00ff000100ff, 0xffff00ff00010000, 0xffff00ff0100ff00,
+    0xffff00ff01000100, 0xffff00ff01010000, 0xffff0000ffffff00, 0xffff0000ffff00ff,
+    0xffff0000ffff0000, 0xffff0000ffff0001, 0xffff0000ff000000, 0xffff0000ff0001ff,
+    0xffff0000ff000101, 0xffff0000ff010100, 0xffff000000ffffff, 0xffff000000ff0000,
+    0xffff000000ff0101, 0xffff00000000ffff, 0xffff00000000ff00, 0xffff0000000000ff,
+    0xffff000000000000, 0xffff000000000001, 0xffff000000000100, 0xffff00000001ffff,
+    0xffff00000001ff01, 0xffff000000010000, 0xffff0000000101ff, 0xffff000000010101,
+    0xffff000001ffff00, 0xffff00000100ff00, 0xffff000001000000, 0xffff0000010001ff,
+    0xffff000001000101, 0xffff00000101ff00, 0xffff0000010100ff, 0xffff000001010000,
+    0xffff000001010001, 0xffff000001010100, 0xffff0001ff0000ff, 0xffff0001ff000100,
+    0xffff000100ffff00, 0xffff000100ff00ff, 0xffff00010000ffff, 0xffff00010000ff01,
+    0xffff000100000000, 0xffff0001000001ff, 0xffff00010001ffff, 0xffff00010001ff00,
+    0xffff000100010001, 0xffff000100010100, 0xffff000101ff0000, 0xffff00010100ff00,
+    0xffff0001010000ff, 0xffff000101000100, 0xffff01ffffffffff, 0xffff01ffffffff01,
+    0xffff01ffffff01ff, 0xffff01ffffff0101, 0xffff01ffff000000, 0xffff01ffff01ffff,
+    0xffff01ffff01ff01, 0xffff01ffff0101ff, 0xffff01ffff010101, 0xffff01ff00ff0000,
+    0xffff01ff0000ff00, 0xffff01ff00000001, 0xffff01ff00010000, 0xffff01ff01ffffff,
+    0xffff01ff01ffff01, 0xffff01ff01ff01ff, 0xffff01ff01ff0101, 0xffff01ff01000000,
+    0xffff01ff0101ffff, 0xffff01ff0101ff01, 0xffff01ff010101ff, 0xffff01ff01010101,
+    0xffff0100ffff0000, 0xffff0100ff00ff00, 0xffff0100ff0000ff, 0xffff0100ff000100,
+    0xffff0100ff0100ff, 0xffff0100ff010000, 0xffff010000ffff00, 0xffff01000000ffff,
+    0xffff01000000ff00, 0xffff010000000000, 0xffff01000001ff00, 0xffff0100000100ff,
+    0xffff010000010100, 0xffff01000100ff00, 0xffff0100010000ff, 0xffff010001000001,
+    0xffff010001000100, 0xffff010001010000, 0xffff0101ffffffff, 0xffff0101ffffff01,
+    0xffff0101ffff01ff, 0xffff0101ffff0101, 0xffff0101ff000000, 0xffff0101ff01ffff,
+    0xffff0101ff01ff01, 0xffff0101ff0101ff, 0xffff0101ff010101, 0xffff010100ff0000,
+    0xffff01010000ff00, 0xffff010100000100, 0xffff01010001ff00, 0xffff010100010000,
+    0xffff010101ffffff, 0xffff010101ffff01, 0xffff010101ff0000, 0xffff010101ff01ff,
+    0xffff010101ff0101, 0xffff010101000000, 0xffff01010101ffff, 0xffff01010101ff01,
+    0xffff0101010101ff, 0xffff010101010101, 0xff00ffffff00ffff, 0xff00ffffff00ff00,
+    0xff00ffffff0000ff, 0xff00ffffff000100, 0xff00ffffff0100ff, 0xff00ffffff010000,
+    0xff00ffff00ffff00, 0xff00ffff00ff00ff, 0xff00ffff0000ffff, 0xff00ffff00000000,
+    0xff00ffff000001ff, 0xff00ffff0001ff00, 0xff00ffff000100ff, 0xff00ffff00010000,
+    0xff00ffff00010100, 0xff00ffff0100ff00, 0xff00ffff010000ff, 0xff00ffff01000001,
+    0xff00ffff0101ff00, 0xff00ffff01010000, 0xff00ff00ffffff00, 0xff00ff00ffff00ff,
+    0xff00ff00ffff0001, 0xff00ff00ffff0100, 0xff00ff00ff00ffff, 0xff00ff00ff00ff01,
+    0xff00ff00ff000000, 0xff00ff00ff0001ff, 0xff00ff00ff01ff00, 0xff00ff00ff0100ff,
+    0xff00ff00ff010100, 0xff00ff0000ff0000, 0xff00ff0000ff0101, 0xff00ff000000ffff,
+    0xff00ff000000ff00, 0xff00ff000000ff01, 0xff00ff00000000ff, 0xff00ff0000000000,
+    0xff00ff0000000001, 0xff00ff0000000100, 0xff00ff000001ffff, 0xff00ff0000010000,
+    0xff00ff0001ff00ff, 0xff00ff000100ff01, 0xff00ff0001000000, 0xff00ff000101ff00,
+    0xff00ff00010100ff, 0xff00ff01ff00ff00, 0xff00ff01ff0000ff, 0xff00ff01ff000001,
+    0xff00ff01ff010000, 0xff00ff0100ffffff, 0xff00ff0100ff0001, 0xff00ff0100ff0100,
+    0xff00ff010000ff01, 0xff00ff0100000000, 0xff00ff01000001ff, 0xff00ff0100000101,
+    0xff00ff01000100ff, 0xff00ff0100010001, 0xff00ff0101ff0000, 0xff00ff010100ff00,
+    0xff00ff01010000ff, 0xff00ff0101000001, 0xff00ff0101010000, 0xff0000ffffffff00,
+    0xff0000ffffff0001, 0xff0000ffffff0100, 0xff0000ffff0000ff, 0xff0000ffff000000,
+    0xff0000ffff0001ff, 0xff0000ffff000100, 0xff0000ffff01ff00, 0xff0000ffff010001,
+    0xff0000ff00ffff00, 0xff0000ff00ff0000, 0xff0000ff00ff0001, 0xff0000ff00ff01ff,
+    0xff0000ff00ff0101, 0xff0000ff0000ff00, 0xff0000ff000000ff, 0xff0000ff00000000,
+    0xff0000ff00000001, 0xff0000ff00000100, 0xff0000ff0001ff01, 0xff0000ff00010000,
+    0xff0000ff000101ff, 0xff0000ff01ff00ff, 0xff0000ff01ff0100, 0xff0000ff0100ffff,
+    0xff0000ff010000ff, 0xff0000ff01000000, 0xff0000ff010001ff, 0xff0000ff01000100,
+    0xff0000ff01000101, 0xff0000ff0101ff00, 0xff0000ff010100ff, 0xff0000ff01010000,
+    0xff0000ff01010100, 0xff000000ffffff01, 0xff000000ffff0000, 0xff000000ffff0101,
+    0xff000000ff00ff00, 0xff000000ff0000ff, 0xff000000ff000000, 0xff000000ff000001,
+    0xff000000ff000100, 0xff000000ff01ffff, 0xff000000ff01ff01, 0xff000000ff010000,
+    0xff000000ff0101ff, 0xff000000ff010101, 0xff00000000ffff00, 0xff00000000ff00ff,
+    0xff00000000ff0000, 0xff00000000ff0001, 0xff0000000000ff00, 0xff0000000000ff01,
+    0xff000000000000ff, 0xff00000000000000, 0xff00000000000001, 0xff00000000000100,
+    0xff00000000000101, 0xff0000000001ff00, 0xff000000000100ff, 0xff00000000010000,
+    0xff00000000010001, 0xff00000000010100, 0xff00000001ffffff, 0xff00000001ffff01,
+    0xff00000001ff00ff, 0xff00000001ff0000, 0xff00000001ff01ff, 0xff00000001ff0101,
+    0xff0000000100ffff, 0xff0000000100ff00, 0xff000000010000ff, 0xff00000001000000,
+    0xff00000001000001, 0xff00000001000100, 0xff00000001000101, 0xff0000000101ffff,
+    0xff0000000101ff01, 0xff00000001010000, 0xff000001ffffff00, 0xff000001ffff00ff,
+    0xff000001ffff0000, 0xff000001ffff0001, 0xff000001ff000000, 0xff000001ff000001,
+    0xff000001ff0001ff, 0xff000001ff000101, 0xff000001ff01ff00, 0xff000001ff010001,
+    0xff00000100ffffff, 0xff00000100ffff01, 0xff00000100ff00ff, 0xff00000100ff0000,
+    0xff00000100ff01ff, 0xff00000100ff0101, 0xff0000010000ff00, 0xff00000100000000,
+    0xff00000100000001, 0xff000001000001ff, 0xff00000100000100, 0xff0000010001ff00,
+    0xff000001000100ff, 0xff00000100010000, 0xff000001000101ff, 0xff00000100010100,
+    0xff00000100010101, 0xff00000101ff0001, 0xff00000101ff0101, 0xff0000010100ff01,
+    0xff00000101000000, 0xff000001010100ff, 0xff00000101010100, 0xff0001ffff00ff00,
+    0xff0001ffff000001, 0xff0001ffff010000, 0xff0001ff00ffff00, 0xff0001ff00ff00ff,
+    0xff0001ff00ff0001, 0xff0001ff00ff0100, 0xff0001ff0000ffff, 0xff0001ff00000000,
+    0xff0001ff000001ff, 0xff0001ff00000101, 0xff0001ff0001ffff, 0xff0001ff0001ff00,
+    0xff0001ff000100ff, 0xff0001ff00010001, 0xff0001ff00010100, 0xff0001ff01ff0000,
+    0xff0001ff0100ff00, 0xff0001ff010000ff, 0xff0001ff01010000, 0xff000100ff00ffff,
+    0xff000100ff00ff01, 0xff000100ff000000, 0xff000100ff000101, 0xff000100ff01ff00,
+    0xff000100ff010000, 0xff00010000ffff01, 0xff00010000ff00ff, 0xff00010000ff0000,
+    0xff00010000ff01ff, 0xff0001000000ff00, 0xff000100000000ff, 0xff00010000000000,
+    0xff00010000000001, 0xff00010000000100, 0xff00010000000101, 0xff0001000001ffff,
+    0xff00010000010000, 0xff00010000010101, 0xff00010001ff0100, 0xff0001000100ff00,
+    0xff0001000100ff01, 0xff00010001000000, 0xff000100010001ff, 0xff0001000101ff00,
+    0xff00010001010001, 0xff00010001010100, 0xff000101ffff0100, 0xff000101ff000001,
+    0xff000101ff0100ff, 0xff000101ff010001, 0xff00010100ff00ff, 0xff00010100ff0001,
+    0xff00010100ff0100, 0xff0001010000ffff, 0xff0001010000ff01, 0xff00010100000000,
+    0xff000101000001ff, 0xff0001010001ff00, 0xff00010100010001, 0xff00010100010100,
+    0xff00010101ff0000, 0xff0001010100ff00, 0xff00010101000001, 0xff00010101000101,
+    0xff01ffffffffffff, 0xff01ffffffffff01, 0xff01ffffffff01ff, 0xff01ffffffff0101,
+    0xff01ffffff000000, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff010000,
+    0xff01ffffff0101ff, 0xff01ffffff010101, 0xff01ffff00ff0000, 0xff01ffff0000ff00,
+    0xff01ffff00000100, 0xff01ffff0001ff00, 0xff01ffff00010000, 0xff01ffff01ffffff,
+    0xff01ffff01ffff01, 0xff01ffff01ff01ff, 0xff01ffff01ff0101, 0xff01ffff01000000,
+    0xff01ffff0101ffff, 0xff01ffff0101ff01, 0xff01ffff01010000, 0xff01ffff010101ff,
+    0xff01ffff01010101, 0xff01ff00ffff0000, 0xff01ff00ff00ff00, 0xff01ff00ff0000ff,
+    0xff01ff00ff000100, 0xff01ff00ff010000, 0xff01ff0000ffff01, 0xff01ff0000ff00ff,
+    0xff01ff0000ff0100, 0xff01ff0000000000, 0xff01ff00000001ff, 0xff01ff0000000101,
+    0xff01ff000001ff00, 0xff01ff00000100ff, 0xff01ff0000010000, 0xff01ff0000010001,
+    0xff01ff0001ff0000, 0xff01ff000100ffff, 0xff01ff0001000001, 0xff01ff0001000100,
+    0xff01ff0001010000, 0xff01ff01ffffff00, 0xff01ff01ffff01ff, 0xff01ff01ffff0101,
+    0xff01ff01ff00ff00, 0xff01ff01ff000000, 0xff01ff01ff01ffff, 0xff01ff01ff01ff01,
+    0xff01ff01ff0101ff, 0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff010000ff00,
+    0xff01ff0100000001, 0xff01ff0100000100, 0xff01ff0100010000, 0xff01ff0101ffff00,
+    0xff01ff0101ff01ff, 0xff01ff0101ff0101, 0xff01ff010100ff00, 0xff01ff0101000000,
+    0xff01ff010101ffff, 0xff01ff010101ff01, 0xff01ff01010101ff, 0xff01ff0101010101,
+    0xff0100ffffff0000, 0xff0100ffff0000ff, 0xff0100ffff000001, 0xff0100ffff000100,
+    0xff0100ffff010000, 0xff0100ff00ff00ff, 0xff0100ff00ff0000, 0xff0100ff00ff0001,
+    0xff0100ff00ff0100, 0xff0100ff0000ff01, 0xff0100ff00000000, 0xff0100ff000001ff,
+    0xff0100ff00000101, 0xff0100ff00010001, 0xff0100ff01ff0000, 0xff0100ff0100ff00,
+    0xff0100ff010000ff, 0xff0100ff01000100, 0xff0100ff0101ff00, 0xff0100ff01010000,
+    0xff010000ffff0100, 0xff010000ff000000, 0xff010000ff01ff00, 0xff010000ff010100,
+    0xff01000000ffffff, 0xff01000000ff0000, 0xff01000000ff01ff, 0xff0100000000ff00,
+    0xff010000000000ff, 0xff01000000000000, 0xff01000000000100, 0xff0100000001ff01,
+    0xff01000000010000, 0xff010000000101ff, 0xff01000001ff0100, 0xff0100000100ffff,
+    0xff010000010000ff, 0xff01000001000000, 0xff010000010001ff, 0xff01000001000101,
+    0xff0100000101ff00, 0xff010000010100ff, 0xff01000001010001, 0xff01000001010100,
+    0xff010001ffff0000, 0xff010001ff00ffff, 0xff010001ff00ff01, 0xff010001ff000100,
+    0xff010001ff010000, 0xff01000100ffff00, 0xff01000100ff0100, 0xff01000100000000,
+    0xff0100010001ffff, 0xff0100010001ff00, 0xff01000100010100, 0xff01000101ff00ff,
+    0xff01000101ff0001, 0xff0100010100ffff, 0xff01000101000101, 0xff0101ffffffffff,
+    0xff0101ffffffff01, 0xff0101ffffff01ff, 0xff0101ffffff0101, 0xff0101ffff000000,
+    0xff0101ffff01ffff, 0xff0101ffff01ff01, 0xff0101ffff0101ff, 0xff0101ffff010101,
+    0xff0101ff00ff0000, 0xff0101ff0000ff00, 0xff0101ff000000ff, 0xff0101ff00010000,
+    0xff0101ff01ffffff, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101,
+    0xff0101ff0101ffff, 0xff0101ff0101ff01, 0xff0101ff010101ff, 0xff0101ff01010101,
+    0xff010100ffff0100, 0xff010100ff00ff00, 0xff010100ff0000ff, 0xff010100ff000100,
+    0xff010100ff010000, 0xff01010000ff0001, 0xff01010000ff0100, 0xff0101000000ff01,
+    0xff01010000000000, 0xff0101000001ff00, 0xff010100000100ff, 0xff01010000010001,
+    0xff01010000010100, 0xff01010001ff0000, 0xff0101000100ffff, 0xff01010001000001,
+    0xff01010001000100, 0xff010100010100ff, 0xff01010001010000, 0xff010101ffffffff,
+    0xff010101ffffff01, 0xff010101ffff01ff, 0xff010101ffff0101, 0xff010101ff01ffff,
+    0xff010101ff01ff01, 0xff010101ff0101ff, 0xff010101ff010101, 0xff01010100ff0000,
+    0xff0101010000ff00, 0xff01010100000001, 0xff01010100000100, 0xff01010100010000,
+    0xff01010101ffffff, 0xff01010101ffff01, 0xff01010101ff01ff, 0xff01010101ff0101,
+    0xff01010101000000, 0xff0101010101ffff, 0xff0101010101ff01, 0xff010101010101ff,
+    0xff01010101010101, 0x00ffffffffff0000, 0x00ffffffff00ff00, 0x00ffffffff000001,
+    0x00ffffffff010000, 0x00ffffff00ff0100, 0x00ffffff0000ff01, 0x00ffffff00000000,
+    0x00ffffff000001ff, 0x00ffffff00000101, 0x00ffffff0001ff00, 0x00ffffff000100ff,
+    0x00ffffff00010001, 0x00ffffff010000ff, 0x00ffffff01000100, 0x00ffffff0101ff00,
+    0x00ffffff01010001, 0x00ffff00ffffffff, 0x00ffff00ffffff00, 0x00ffff00ffff00ff,
+    0x00ffff00ffff0001, 0x00ffff00ffff0100, 0x00ffff00ff00ff01, 0x00ffff00ff000000,
+    0x00ffff00ff000001, 0x00ffff00ff0001ff, 0x00ffff00ff000101, 0x00ffff00ff01ff00,
+    0x00ffff00ff010001, 0x00ffff00ff010100, 0x00ffff0000ff0000, 0x00ffff0000ff01ff,
+    0x00ffff0000ff0101, 0x00ffff000000ff00, 0x00ffff00000000ff, 0x00ffff0000000000,
+    0x00ffff0000000001, 0x00ffff0000000100, 0x00ffff0000000101, 0x00ffff0000010000,
+    0x00ffff00000101ff, 0x00ffff0000010101, 0x00ffff0001ffff00, 0x00ffff0001ff00ff,
+    0x00ffff0001ff0001, 0x00ffff000100ffff, 0x00ffff000100ff01, 0x00ffff0001000000,
+    0x00ffff000101ffff, 0x00ffff000101ff00, 0x00ffff000101ff01, 0x00ffff01ffff0000,
+    0x00ffff01ff00ff00, 0x00ffff01ff0000ff, 0x00ffff01ff000001, 0x00ffff01ff010000,
+    0x00ffff0100ffff00, 0x00ffff010000ff01, 0x00ffff0100000000, 0x00ffff0100000101,
+    0x00ffff01000100ff, 0x00ffff0100010100, 0x00ffff0101ff0100, 0x00ffff01010000ff,
+    0x00ffff0101010000, 0x00ff00ffffffff00, 0x00ff00ffff000000, 0x00ff00ffff000100,
+    0x00ff00ffff010100, 0x00ff00ff00ff0000, 0x00ff00ff00ff01ff, 0x00ff00ff00ff0101,
+    0x00ff00ff0000ff00, 0x00ff00ff000000ff, 0x00ff00ff00000000, 0x00ff00ff00000001,
+    0x00ff00ff0001ff00, 0x00ff00ff0001ff01, 0x00ff00ff00010000, 0x00ff00ff000101ff,
+    0x00ff00ff00010101, 0x00ff00ff01ffff00, 0x00ff00ff01ff0001, 0x00ff00ff01ff0100,
+    0x00ff00ff0100ffff, 0x00ff00ff0100ff01, 0x00ff00ff01000000, 0x00ff00ff0101ffff,
+    0x00ff00ff0101ff00, 0x00ff00ff01010100, 0x00ff0000ffffff00, 0x00ff0000ffffff01,
+    0x00ff0000ffff0000, 0x00ff0000ffff0101, 0x00ff0000ff00ff00, 0x00ff0000ff0000ff,
+    0x00ff0000ff000000, 0x00ff0000ff000001, 0x00ff0000ff000100, 0x00ff0000ff01ffff,
+    0x00ff0000ff010000, 0x00ff0000ff010101, 0x00ff000000ffff00, 0x00ff000000ff00ff,
+    0x00ff000000ff0000, 0x00ff000000ff0001, 0x00ff000000ff0100, 0x00ff00000000ffff,
+    0x00ff00000000ff00, 0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000001,
+    0x00ff0000000001ff, 0x00ff000000000100, 0x00ff00000001ff00, 0x00ff0000000100ff,
+    0x00ff000000010000, 0x00ff000000010001, 0x00ff000000010100, 0x00ff000001ffff01,
+    0x00ff000001ff00ff, 0x00ff000001ff0000, 0x00ff000001ff01ff, 0x00ff00000100ff00,
+    0x00ff0000010000ff, 0x00ff000001000000, 0x00ff000001000001, 0x00ff000001000100,
+    0x00ff000001000101, 0x00ff000001010000, 0x00ff0000010101ff, 0x00ff000001010101,
+    0x00ff0001ffffff00, 0x00ff0001ffff0000, 0x00ff0001ffff0100, 0x00ff0001ff0000ff,
+    0x00ff0001ff000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101, 0x00ff0001ff01ff00,
+    0x00ff0001ff0100ff, 0x00ff0001ff010100, 0x00ff000100ffffff, 0x00ff000100ffff01,
+    0x00ff000100ff0000, 0x00ff000100ff01ff, 0x00ff00010000ffff, 0x00ff00010000ff00,
+    0x00ff00010000ff01, 0x00ff000100000000, 0x00ff000100000001, 0x00ff000100000100,
+    0x00ff00010001ff01, 0x00ff000100010000, 0x00ff0001000101ff, 0x00ff000101ffff00,
+    0x00ff000101ff0000, 0x00ff000101ff0101, 0x00ff0001010000ff, 0x00ff000101000000,
+    0x00ff00010101ff00, 0x00ff0001010100ff, 0x00ff000101010001, 0x00ff01ffffff0000,
+    0x00ff01ffff00ff00, 0x00ff01ffff000000, 0x00ff01ffff000101, 0x00ff01ffff010000,
+    0x00ff01ff00ffff01, 0x00ff01ff00ff0100, 0x00ff01ff0000ffff, 0x00ff01ff00000000,
+    0x00ff01ff000001ff, 0x00ff01ff0001ff00, 0x00ff01ff000100ff, 0x00ff01ff00010001,
+    0x00ff01ff00010100, 0x00ff01ff01ff0000, 0x00ff01ff0100ff00, 0x00ff01ff010000ff,
+    0x00ff01ff01000001, 0x00ff01ff01000100, 0x00ff01ff01010000, 0x00ff0100ffffff00,
+    0x00ff0100ffff0000, 0x00ff0100ffff0001, 0x00ff0100ffff0101, 0x00ff0100ff00ffff,
+    0x00ff0100ff0000ff, 0x00ff0100ff000000, 0x00ff0100ff0001ff, 0x00ff0100ff01ff00,
+    0x00ff0100ff0100ff, 0x00ff0100ff010001, 0x00ff010000ffffff, 0x00ff010000ff0000,
+    0x00ff010000ff0101, 0x00ff01000000ff00, 0x00ff01000000ff01, 0x00ff0100000000ff,
+    0x00ff010000000000, 0x00ff010000000001, 0x00ff010000000100, 0x00ff01000001ffff,
+    0x00ff01000001ff01, 0x00ff010000010000, 0x00ff010000010001, 0x00ff010000010101,
+    0x00ff010001ff0001, 0x00ff010001ff0100, 0x00ff01000100ff01, 0x00ff010001000000,
+    0x00ff010001000001, 0x00ff0100010001ff, 0x00ff01000101ff00, 0x00ff0100010100ff,
+    0x00ff010001010001, 0x00ff010001010100, 0x00ff0101ff000001, 0x00ff010100ff00ff,
+    0x00ff010100ff0001, 0x00ff010100ff0100, 0x00ff010100000000, 0x00ff0101000001ff,
+    0x00ff010100000101, 0x00ff0101000100ff, 0x00ff010100010100, 0x00ff0101010000ff,
+    0x00ff010101010000, 0x0000ffffffffff00, 0x0000ffffffff00ff, 0x0000ffffffff0000,
+    0x0000ffffffff0001, 0x0000ffffffff0100, 0x0000ffffff00ff01, 0x0000ffffff000000,
+    0x0000ffffff000101, 0x0000ffffff01ff00, 0x0000ffffff0100ff, 0x0000ffffff010100,
+    0x0000ffff00ffffff, 0x0000ffff00ff0000, 0x0000ffff00ff01ff, 0x0000ffff0000ff00,
+    0x0000ffff000000ff, 0x0000ffff00000000, 0x0000ffff00000001, 0x0000ffff00000100,
+    0x0000ffff00010000, 0x0000ffff000101ff, 0x0000ffff01ff0001, 0x0000ffff01ff0100,
+    0x0000ffff01000000, 0x0000ffff010001ff, 0x0000ffff0101ffff, 0x0000ffff0101ff00,
+    0x0000ffff01010001, 0x0000ffff01010100, 0x0000ff00ffff0000, 0x0000ff00ffff01ff,
+    0x0000ff00ffff0100, 0x0000ff00ffff0101, 0x0000ff00ff00ff00, 0x0000ff00ff0000ff,
+    0x0000ff00ff000000, 0x0000ff00ff000001, 0x0000ff00ff0001ff, 0x0000ff00ff000100,
+    0x0000ff00ff01ffff, 0x0000ff00ff010000, 0x0000ff00ff010001, 0x0000ff00ff0101ff,
+    0x0000ff00ff010101, 0x0000ff0000ffff00, 0x0000ff0000ff00ff, 0x0000ff0000ff0000,
+    0x0000ff0000ff0001, 0x0000ff0000ff0100, 0x0000ff000000ffff, 0x0000ff000000ff00,
+    0x0000ff000000ff01, 0x0000ff00000000ff, 0x0000ff0000000000, 0x0000ff0000000001,
+    0x0000ff00000001ff, 0x0000ff0000000100, 0x0000ff0000000101, 0x0000ff000001ff00,
+    0x0000ff00000100ff, 0x0000ff0000010000, 0x0000ff0000010001, 0x0000ff0000010100,
+    0x0000ff0001ffff01, 0x0000ff0001ff0000, 0x0000ff000100ff00, 0x0000ff00010000ff,
+    0x0000ff0001000000, 0x0000ff0001000001, 0x0000ff0001000100, 0x0000ff000101ffff,
+    0x0000ff0001010000, 0x0000ff0001010101, 0x0000ff01ffffff00, 0x0000ff01ffff0001,
+    0x0000ff01ff00ff01, 0x0000ff01ff000000, 0x0000ff01ff000101, 0x0000ff01ff01ff00,
+    0x0000ff01ff0100ff, 0x0000ff0100ffff01, 0x0000ff0100ff0000, 0x0000ff0100ff0101,
+    0x0000ff010000ff00, 0x0000ff01000000ff, 0x0000ff0100000000, 0x0000ff0100000001,
+    0x0000ff0100000100, 0x0000ff010001ff01, 0x0000ff0100010000, 0x0000ff0101ff0000,
+    0x0000ff010100ffff, 0x0000ff010100ff01, 0x0000ff0101000000, 0x0000ff0101000100,
+    0x0000ff0101000101, 0x0000ff01010100ff, 0x000000ffffff00ff, 0x000000ffffff0000,
+    0x000000ffff00ff00, 0x000000ffff0000ff, 0x000000ffff000000, 0x000000ffff000001,
+    0x000000ffff0001ff, 0x000000ffff000100, 0x000000ffff01ff00, 0x000000ffff010000,
+    0x000000ffff0101ff, 0x000000ffff010101, 0x000000ff00ffff00, 0x000000ff00ff00ff,
+    0x000000ff00ff0000, 0x000000ff00ff0001, 0x000000ff00ff0100, 0x000000ff00ff0101,
+    0x000000ff0000ffff, 0x000000ff0000ff00, 0x000000ff000000ff, 0x000000ff00000000,
+    0x000000ff00000001, 0x000000ff000001ff, 0x000000ff00000100, 0x000000ff00000101,
+    0x000000ff0001ff00, 0x000000ff0001ff01, 0x000000ff000100ff, 0x000000ff00010000,
+    0x000000ff00010001, 0x000000ff00010100, 0x000000ff01ffffff, 0x000000ff01ff01ff,
+    0x000000ff01ff0101, 0x000000ff0100ff00, 0x000000ff010000ff, 0x000000ff01000000,
+    0x000000ff01000001, 0x000000ff01000100, 0x000000ff0101ff00, 0x000000ff010100ff,
+    0x000000ff01010000, 0x000000ff01010101, 0x00000000ffffff00, 0x00000000ffffff01,
+    0x00000000ffff00ff, 0x00000000ffff0000, 0x00000000ffff0001, 0x00000000ffff0100,
+    0x00000000ff00ffff, 0x00000000ff00ff00, 0x00000000ff00ff01, 0x00000000ff0000ff,
+    0x00000000ff000000, 0x00000000ff000001, 0x00000000ff000100, 0x00000000ff000101,
+    0x00000000ff01ff00, 0x00000000ff0100ff, 0x00000000ff010000, 0x00000000ff010001,
+    0x00000000ff010100, 0x0000000000ffffff, 0x0000000000ffff00, 0x0000000000ffff01,
+    0x0000000000ff00ff, 0x0000000000ff0000, 0x0000000000ff0001, 0x0000000000ff01ff,
+    0x0000000000ff0100, 0x000000000000ffff, 0x000000000000ff00, 0x000000000000ff01,
+    0x00000000000000ff, 0x0000000000000000, 0x0000000000000001, 0x00000000000001ff,
+    0x0000000000000100, 0x0000000000000101, 0x000000000001ffff, 0x000000000001ff00,
+    0x00000000000100ff, 0x0000000000010000, 0x0000000000010001, 0x00000000000101ff,
+    0x0000000000010100, 0x0000000000010101, 0x0000000001ffff00, 0x0000000001ff00ff,
+    0x0000000001ff0000, 0x0000000001ff0100, 0x0000000001ff0101, 0x000000000100ffff,
+    0x000000000100ff00, 0x00000000010000ff, 0x0000000001000000, 0x0000000001000001,
+    0x00000000010001ff, 0x0000000001000100, 0x000000000101ff00, 0x00000000010100ff,
+    0x0000000001010000, 0x0000000001010001, 0x0000000001010100, 0x00000001ffffffff,
+    0x00000001ffffff00, 0x00000001ffffff01, 0x00000001ffff00ff, 0x00000001ffff0001,
+    0x00000001ffff01ff, 0x00000001ffff0100, 0x00000001ff00ff00, 0x00000001ff0000ff,
+    0x00000001ff000000, 0x00000001ff0001ff, 0x00000001ff000100, 0x00000001ff01ffff,
+    0x00000001ff01ff00, 0x00000001ff01ff01, 0x00000001ff0100ff, 0x00000001ff010000,
+    0x00000001ff010001, 0x00000001ff0101ff, 0x00000001ff010100, 0x0000000100ffff00,
+    0x0000000100ff0000, 0x0000000100ff0001, 0x0000000100ff01ff, 0x0000000100ff0100,
+    0x0000000100ff0101, 0x000000010000ffff, 0x000000010000ff00, 0x000000010000ff01,
+    0x00000001000000ff, 0x0000000100000000, 0x0000000100000001, 0x00000001000001ff,
+    0x0000000100000100, 0x0000000100000101, 0x000000010001ff00, 0x00000001000100ff,
+    0x0000000100010000, 0x0000000100010100, 0x0000000101ffff01, 0x0000000101ff0000,
+    0x0000000101ff0001, 0x0000000101ff01ff, 0x0000000101ff0100, 0x0000000101ff0101,
+    0x000000010100ff00, 0x0000000101000000, 0x0000000101000101, 0x000000010101ff01,
+    0x0000000101010000, 0x0000000101010001, 0x00000001010101ff, 0x0000000101010100,
+    0x000001ffffff00ff, 0x000001ffffff0000, 0x000001ffffff0001, 0x000001ffffff0100,
+    0x000001ffff00ffff, 0x000001ffff000000, 0x000001ffff0001ff, 0x000001ffff01ff00,
+    0x000001ffff010101, 0x000001ff00ff0000, 0x000001ff00ff01ff, 0x000001ff00ff0101,
+    0x000001ff0000ff00, 0x000001ff000000ff, 0x000001ff00000000, 0x000001ff00000001,
+    0x000001ff000001ff, 0x000001ff00000100, 0x000001ff0001ffff, 0x000001ff0001ff01,
+    0x000001ff000100ff, 0x000001ff00010000, 0x000001ff01ffff01, 0x000001ff01ff0100,
+    0x000001ff0100ffff, 0x000001ff0100ff01, 0x000001ff01000000, 0x000001ff010001ff,
+    0x000001ff0101ff00, 0x000001ff01010100, 0x00000100ffffff00, 0x00000100ffffff01,
+    0x00000100ffff0000, 0x00000100ffff0101, 0x00000100ff00ff00, 0x00000100ff0000ff,
+    0x00000100ff000000, 0x00000100ff000001, 0x00000100ff000100, 0x00000100ff010000,
+    0x0000010000ffff00, 0x0000010000ff00ff, 0x0000010000ff0000, 0x0000010000ff0001,
+    0x0000010000ff0100, 0x000001000000ffff, 0x000001000000ff00, 0x000001000000ff01,
+    0x00000100000000ff, 0x0000010000000000, 0x0000010000000001, 0x00000100000001ff,
+    0x0000010000000100, 0x0000010000000101, 0x000001000001ff00, 0x00000100000100ff,
+    0x0000010000010000, 0x0000010000010001, 0x0000010000010100, 0x0000010001ffff00,
+    0x0000010001ff0000, 0x0000010001ff0100, 0x000001000100ff00, 0x00000100010000ff,
+    0x0000010001000000, 0x0000010001000001, 0x00000100010001ff, 0x0000010001000100,
+    0x0000010001010000, 0x00000101ffff00ff, 0x00000101ffff01ff, 0x00000101ff000000,
+    0x00000101ff000101, 0x00000101ff01ffff, 0x00000101ff010000, 0x00000101ff010001,
+    0x00000101ff010100, 0x0000010100ff0000, 0x0000010100ff01ff, 0x0000010100ff0100,
+    0x000001010000ff00, 0x0000010100000000, 0x0000010100000001, 0x00000101000001ff,
+    0x0000010100000100, 0x000001010001ff01, 0x0000010100010000, 0x00000101000101ff,
+    0x0000010100010101, 0x0000010101ffff00, 0x0000010101ff0101, 0x000001010100ff01,
+    0x0000010101000000, 0x0000010101000001, 0x00000101010001ff, 0x0000010101000101,
+    0x000001010101ff00, 0x0001ffffffff0000, 0x0001ffffff0000ff, 0x0001ffffff000001,
+    0x0001ffffff000100, 0x0001ffffff010000, 0x0001ffff00ff00ff, 0x0001ffff0000ffff,
+    0x0001ffff00000000, 0x0001ffff00000001, 0x0001ffff000001ff, 0x0001ffff00000101,
+    0x0001ffff0001ff00, 0x0001ffff000100ff, 0x0001ffff00010001, 0x0001ffff00010100,
+    0x0001ffff01ffff00, 0x0001ffff01000001, 0x0001ffff01010000, 0x0001ff00ffffff00,
+    0x0001ff00ffff00ff, 0x0001ff00ffff0001, 0x0001ff00ffff0100, 0x0001ff00ff00ff01,
+    0x0001ff00ff000000, 0x0001ff00ff01ff00, 0x0001ff00ff01ff01, 0x0001ff00ff010001,
+    0x0001ff00ff010100, 0x0001ff0000ff0000, 0x0001ff0000ff0100, 0x0001ff000000ff00,
+    0x0001ff0000000000, 0x0001ff0000000001, 0x0001ff0000000100, 0x0001ff0000010000,
+    0x0001ff0000010001, 0x0001ff0000010101, 0x0001ff0001ff00ff, 0x0001ff0001ff0101,
+    0x0001ff000100ff01, 0x0001ff0001000000, 0x0001ff000101ff00, 0x0001ff0001010001,
+    0x0001ff0001010100, 0x0001ff01ff00ff00, 0x0001ff01ff000001, 0x0001ff01ff000100,
+    0x0001ff0100ffffff, 0x0001ff0100ffff00, 0x0001ff0100ff0001, 0x0001ff0100000000,
+    0x0001ff0100000001, 0x0001ff01000001ff, 0x0001ff010001ffff, 0x0001ff0101ff0000,
+    0x0001ff010100ff00, 0x0001ff0101000001, 0x0001ff0101010000, 0x000100ffff00ff00,
+    0x000100ffff00ff01, 0x000100ffff000000, 0x000100ffff000001, 0x000100ffff000101,
+    0x000100ffff01ff00, 0x000100ffff010001, 0x000100ffff010100, 0x000100ff00ffffff,
+    0x000100ff00ffff01, 0x000100ff00ff0000, 0x000100ff00ff01ff, 0x000100ff00ff0101,
+    0x000100ff0000ff00, 0x000100ff000000ff, 0x000100ff00000000, 0x000100ff00000001,
+    0x000100ff00000100, 0x000100ff00000101, 0x000100ff0001ffff, 0x000100ff0001ff01,
+    0x000100ff00010000, 0x000100ff01ff00ff, 0x000100ff01ff0000, 0x000100ff01ff0100,
+    0x000100ff0100ffff, 0x000100ff0100ff01, 0x000100ff010000ff, 0x000100ff01000000,
+    0x000100ff01000001, 0x000100ff010001ff, 0x000100ff01000101, 0x000100ff0101ff00,
+    0x000100ff010100ff, 0x000100ff01010100, 0x00010000ffff0000, 0x00010000ffff01ff,
+    0x00010000ffff0101, 0x00010000ff00ff00, 0x00010000ff000000, 0x00010000ff000001,
+    0x00010000ff000100, 0x0001000000ff00ff, 0x0001000000ff0000, 0x0001000000ff0001,
+    0x0001000000ff0100, 0x000100000000ffff, 0x000100000000ff00, 0x00010000000000ff,
+    0x0001000000000000, 0x0001000000000001, 0x0001000000000100, 0x000100000001ff00,
+    0x00010000000100ff, 0x0001000000010000, 0x0001000000010001, 0x0001000000010100,
+    0x0001000001ff0001, 0x0001000001ff0100, 0x0001000001ff0101, 0x000100000100ff00,
+    0x0001000001000000, 0x0001000001000001, 0x0001000001000100, 0x0001000001000101,
+    0x000100000101ff01, 0x0001000001010000, 0x0001000001010001, 0x00010000010101ff,
+    0x00010001ffffff01, 0x00010001ffff0100, 0x00010001ff000000, 0x00010001ff01ffff,
+    0x00010001ff010001, 0x00010001ff0101ff, 0x00010001ff010100, 0x0001000100ffffff,
+    0x0001000100ff0000, 0x0001000100ff01ff, 0x0001000100ff0101, 0x000100010000ff00,
+    0x00010001000000ff, 0x0001000100000000, 0x0001000100000001, 0x00010001000001ff,
+    0x0001000100000101, 0x000100010001ffff, 0x0001000100010000, 0x00010001000101ff,
+    0x0001000101ffffff, 0x0001000101ffff01, 0x0001000101ff0000, 0x0001000101ff0101,
+    0x00010001010000ff, 0x0001000101000001, 0x00010001010001ff, 0x0001000101000100,
+    0x000100010101ffff, 0x00010001010100ff, 0x0001000101010001, 0x0001000101010101,
+    0x000101ffff000001, 0x000101ffff000100, 0x000101ffff010000, 0x000101ff00ffff00,
+    0x000101ff0000ff01, 0x000101ff00000000, 0x000101ff00000101, 0x000101ff0001ff00,
+    0x000101ff00010100, 0x000101ff01ff0000, 0x000101ff0100ff00, 0x000101ff010001ff,
+    0x000101ff01010001, 0x00010100ffffff00, 0x00010100ffff00ff, 0x00010100ff00ffff,
+    0x00010100ff000000, 0x00010100ff01ff00, 0x00010100ff0100ff, 0x00010100ff010001,
+    0x00010100ff010100, 0x0001010000ffffff, 0x0001010000ffff00, 0x0001010000ff0000,
+    0x0001010000ff0001, 0x0001010000ff01ff, 0x000101000000ff00, 0x00010100000000ff,
+    0x0001010000000000, 0x0001010000000001, 0x0001010000000100, 0x000101000001ffff,
+    0x0001010000010000, 0x0001010000010101, 0x0001010001ffff01, 0x0001010001ff00ff,
+    0x0001010001ff0101, 0x0001010001000000, 0x000101000101ff00, 0x00010100010100ff,
+    0x0001010001010000, 0x0001010001010100, 0x00010101ff00ff00, 0x00010101ff000001,
+    0x00010101ff0001ff, 0x0001010100ffff00, 0x0001010100ff00ff, 0x0001010100ff0100,
+    0x000101010000ffff, 0x0001010100000000, 0x00010101000001ff, 0x0001010100000101,
+    0x00010101000100ff, 0x0001010100010000, 0x0001010100010100, 0x0001010101ff0001,
+    0x00010101010000ff, 0x00010101010001ff, 0x0001010101000101, 0x0001010101010001,
+    0x01ffffffffffffff, 0x01ffffffffffff01, 0x01ffffffffff01ff, 0x01ffffffffff0101,
+    0x01ffffffff01ffff, 0x01ffffffff01ff01, 0x01ffffffff0101ff, 0x01ffffffff010101,
+    0x01ffffff00ff0000, 0x01ffffff0000ffff, 0x01ffffff0000ff00, 0x01ffffff000000ff,
+    0x01ffffff00000001, 0x01ffffff00000100, 0x01ffffff00010000, 0x01ffffff01ffffff,
+    0x01ffffff01ffff01, 0x01ffffff01ff01ff, 0x01ffffff01ff0101, 0x01ffffff01000000,
+    0x01ffffff0101ffff, 0x01ffffff0101ff01, 0x01ffffff010101ff, 0x01ffffff01010101,
+    0x01ffff00ffff0000, 0x01ffff00ff00ff00, 0x01ffff00ff0000ff, 0x01ffff00ff000001,
+    0x01ffff00ff000100, 0x01ffff00ff010000, 0x01ffff0000ffff00, 0x01ffff0000ff00ff,
+    0x01ffff0000ff0100, 0x01ffff000000ffff, 0x01ffff000000ff01, 0x01ffff0000000000,
+    0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000000100, 0x01ffff00000100ff,
+    0x01ffff0000010001, 0x01ffff0000010100, 0x01ffff0001ff0000, 0x01ffff0001ff0100,
+    0x01ffff00010000ff, 0x01ffff0001000001, 0x01ffff0001000100, 0x01ffff0001010000,
+    0x01ffff01ffffffff, 0x01ffff01ffffff01, 0x01ffff01ffff01ff, 0x01ffff01ffff0101,
+    0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff01ff01ff01, 0x01ffff01ff0101ff,
+    0x01ffff01ff010101, 0x01ffff010000ff00, 0x01ffff01000000ff, 0x01ffff0100000100,
+    0x01ffff0100010000, 0x01ffff0101ffffff, 0x01ffff0101ffff01, 0x01ffff0101ff01ff,
+    0x01ffff0101ff0101, 0x01ffff0101000000, 0x01ffff010101ffff, 0x01ffff010101ff01,
+    0x01ffff01010101ff, 0x01ffff0101010101, 0x01ff00ffff0000ff, 0x01ff00ffff000100,
+    0x01ff00ff00ffff00, 0x01ff00ff00ff00ff, 0x01ff00ff0000ff00, 0x01ff00ff00000000,
+    0x01ff00ff00000101, 0x01ff00ff0001ff00, 0x01ff00ff000100ff, 0x01ff00ff00010100,
+    0x01ff00ff010000ff, 0x01ff00ff01000100, 0x01ff0000ffffff00, 0x01ff0000ffff0100,
+    0x01ff0000ff00ff01, 0x01ff0000ff000000, 0x01ff0000ff000101, 0x01ff0000ff010001,
+    0x01ff0000ff010100, 0x01ff000000ffffff, 0x01ff000000ffff00, 0x01ff000000ff0000,
+    0x01ff000000ff01ff, 0x01ff00000000ff00, 0x01ff0000000000ff, 0x01ff000000000000,
+    0x01ff000000000001, 0x01ff000000000100, 0x01ff000000000101, 0x01ff000000010000,
+    0x01ff000000010001, 0x01ff0000000101ff, 0x01ff000000010101, 0x01ff000001ffff00,
+    0x01ff000001ff00ff, 0x01ff000001ff0001, 0x01ff000001ff0100, 0x01ff00000100ffff,
+    0x01ff00000100ff01, 0x01ff000001000000, 0x01ff0000010001ff, 0x01ff000001010001,
+    0x01ff0001ff00ff00, 0x01ff0001ff000001, 0x01ff0001ff000100, 0x01ff0001ff010000,
+    0x01ff000100ffff00, 0x01ff000100ff00ff, 0x01ff000100ff0100, 0x01ff000100ff0101,
+    0x01ff00010000ffff, 0x01ff000100000000, 0x01ff000100000100, 0x01ff000100000101,
+    0x01ff00010001ff00, 0x01ff000100010001, 0x01ff000100010101, 0x01ff000101ff0000,
+    0x01ff00010100ff00, 0x01ff000101000101, 0x01ff0001010100ff, 0x01ff01ffffffffff,
+    0x01ff01ffffffff01, 0x01ff01ffffff01ff, 0x01ff01ffffff0101, 0x01ff01ffff000000,
+    0x01ff01ffff01ffff, 0x01ff01ffff01ff01, 0x01ff01ffff0101ff, 0x01ff01ffff010101,
+    0x01ff01ff00ffff00, 0x01ff01ff00ff0000, 0x01ff01ff0000ff00, 0x01ff01ff000000ff,
+    0x01ff01ff00000100, 0x01ff01ff00010000, 0x01ff01ff00010100, 0x01ff01ff01ffffff,
+    0x01ff01ff01ffff01, 0x01ff01ff01ff01ff, 0x01ff01ff01ff0101, 0x01ff01ff01000000,
+    0x01ff01ff0101ffff, 0x01ff01ff0101ff01, 0x01ff01ff010101ff, 0x01ff01ff01010101,
+    0x01ff0100ffff0000, 0x01ff0100ffff0001, 0x01ff0100ff00ff00, 0x01ff0100ff0000ff,
+    0x01ff0100ff000001, 0x01ff0100ff010000, 0x01ff010000ffff00, 0x01ff010000ff00ff,
+    0x01ff010000ff0001, 0x01ff010000ff0100, 0x01ff01000000ffff, 0x01ff01000000ff01,
+    0x01ff010000000000, 0x01ff010000000101, 0x01ff01000001ff00, 0x01ff0100000100ff,
+    0x01ff010001ff0000, 0x01ff010001000001, 0x01ff010001000100, 0x01ff010001010000,
+    0x01ff0101ffffffff, 0x01ff0101ffffff01, 0x01ff0101ffff01ff, 0x01ff0101ffff0101,
+    0x01ff0101ff000000, 0x01ff0101ff01ffff, 0x01ff0101ff01ff01, 0x01ff0101ff0101ff,
+    0x01ff0101ff010101, 0x01ff010100ff0000, 0x01ff01010000ff00, 0x01ff0101000000ff,
+    0x01ff010100000001, 0x01ff010101ffffff, 0x01ff010101ffff01, 0x01ff010101ff01ff,
+    0x01ff010101ff0101, 0x01ff010101000000, 0x01ff01010101ffff, 0x01ff01010101ff01,
+    0x01ff0101010101ff, 0x01ff010101010101, 0x0100ffffffff0000, 0x0100ffffff00ff00,
+    0x0100ffffff000001, 0x0100ffffff0001ff, 0x0100ffffff000100, 0x0100ffffff010000,
+    0x0100ffff00ffff00, 0x0100ffff00ff0001, 0x0100ffff00ff0100, 0x0100ffff00000000,
+    0x0100ffff000001ff, 0x0100ffff00000101, 0x0100ffff00010100, 0x0100ffff00010101,
+    0x0100ffff01ff0000, 0x0100ffff0100ff00, 0x0100ffff010000ff, 0x0100ffff01000001,
+    0x0100ffff01000100, 0x0100ffff01010000, 0x0100ff00ffffff00, 0x0100ff00ffff00ff,
+    0x0100ff00ffff0001, 0x0100ff00ffff0100, 0x0100ff00ff00ffff, 0x0100ff00ff000000,
+    0x0100ff00ff0001ff, 0x0100ff00ff000101, 0x0100ff00ff01ff00, 0x0100ff00ff0100ff,
+    0x0100ff00ff010001, 0x0100ff00ff010100, 0x0100ff0000ffffff, 0x0100ff0000ff0000,
+    0x0100ff000000ffff, 0x0100ff000000ff00, 0x0100ff00000000ff, 0x0100ff0000000000,
+    0x0100ff0000000001, 0x0100ff0000000100, 0x0100ff000001ff01, 0x0100ff0000010000,
+    0x0100ff0001ff00ff, 0x0100ff0001ff0001, 0x0100ff000100ff01, 0x0100ff0001000000,
+    0x0100ff00010001ff, 0x0100ff000101ff00, 0x0100ff00010100ff, 0x0100ff0001010001,
+    0x0100ff0001010100, 0x0100ff01ffff0000, 0x0100ff01ff00ff00, 0x0100ff01ff0000ff,
+    0x0100ff01ff000100, 0x0100ff01ff010000, 0x0100ff0100ff00ff, 0x0100ff0100ff0001,
+    0x0100ff0100ff0100, 0x0100ff010000ffff, 0x0100ff010000ff01, 0x0100ff0100000000,
+    0x0100ff01000001ff, 0x0100ff0100010001, 0x0100ff0100010100, 0x0100ff0101ff0000,
+    0x0100ff01010000ff, 0x0100ff0101000001, 0x0100ff0101010100, 0x010000ffffffff00,
+    0x010000ffffff00ff, 0x010000ffffff0001, 0x010000ffff00ffff, 0x010000ffff000000,
+    0x010000ffff0001ff, 0x010000ffff010001, 0x010000ff00ffffff, 0x010000ff00ff0101,
+    0x010000ff0000ff00, 0x010000ff000000ff, 0x010000ff00000000, 0x010000ff00000001,
+    0x010000ff000001ff, 0x010000ff00000100, 0x010000ff0001ffff, 0x010000ff0001ff00,
+    0x010000ff0001ff01, 0x010000ff00010000, 0x010000ff01ff00ff, 0x010000ff01ff0001,
+    0x010000ff0100ff01, 0x010000ff010000ff, 0x010000ff01000000, 0x010000ff010001ff,
+    0x010000ff0101ff00, 0x010000ff01010100, 0x01000000ffffffff, 0x01000000ffff0000,
+    0x01000000ffff01ff, 0x01000000ffff0101, 0x01000000ff00ffff, 0x01000000ff00ff00,
+    0x01000000ff0000ff, 0x01000000ff000000, 0x01000000ff000001, 0x01000000ff000100,
+    0x01000000ff01ff00, 0x01000000ff010000, 0x01000000ff010100, 0x01000000ff010101,
+    0x0100000000ffff00, 0x0100000000ff00ff, 0x0100000000ff0000, 0x0100000000ff0001,
+    0x0100000000ff0100, 0x010000000000ffff, 0x010000000000ff00, 0x010000000000ff01,
+    0x01000000000000ff, 0x0100000000000000, 0x0100000000000001, 0x01000000000001ff,
+    0x0100000000000100, 0x0100000000000101, 0x010000000001ff00, 0x01000000000100ff,
+    0x0100000000010000, 0x0100000000010001, 0x0100000000010100, 0x0100000001ffff00,
+    0x0100000001ff0000, 0x0100000001ff01ff, 0x010000000100ff00, 0x010000000100ff01,
+    0x01000000010000ff, 0x0100000001000000, 0x0100000001000001, 0x0100000001000100,
+    0x0100000001000101, 0x010000000101ffff, 0x010000000101ff01, 0x0100000001010000,
+    0x01000000010101ff, 0x0100000001010101, 0x01000001ffffff00, 0x01000001ffff00ff,
+    0x01000001ff00ffff, 0x01000001ff000000, 0x01000001ff000100, 0x01000001ff01ffff,
+    0x01000001ff010001, 0x01000001ff010100, 0x0100000100ff0000, 0x0100000100ff01ff,
+    0x0100000100ff0100, 0x010000010000ff00, 0x010000010000ff01, 0x0100000100000000,
+    0x0100000100000001, 0x0100000100000100, 0x0100000100010000, 0x01000001000101ff,
+    0x0100000101ffff01, 0x0100000101ff00ff, 0x0100000101ff0100, 0x0100000101ff0101,
+    0x010000010100ff01, 0x01000001010000ff, 0x0100000101000000, 0x01000001010100ff,
+    0x0100000101010001, 0x0100000101010100, 0x010001ffffff0000, 0x010001ffff000001,
+    0x010001ffff000100, 0x010001ffff010000, 0x010001ff00ffff00, 0x010001ff00ff0001,
+    0x010001ff0000ffff, 0x010001ff0000ff01, 0x010001ff00000000, 0x010001ff00000001,
+    0x010001ff00000101, 0x010001ff000100ff, 0x010001ff00010000, 0x010001ff01ff0000,
+    0x010001ff0100ff00, 0x010001ff01000001, 0x010001ff01000100, 0x010001ff01010000,
+    0x01000100ffff00ff, 0x01000100ffff0001, 0x01000100ffff0100, 0x01000100ff00ffff,
+    0x01000100ff00ff01, 0x01000100ff000000, 0x01000100ff0001ff, 0x01000100ff000101,
+    0x01000100ff01ffff, 0x01000100ff01ff00, 0x01000100ff0100ff, 0x01000100ff010001,
+    0x0100010000ffffff, 0x0100010000ffff01, 0x0100010000ff0000, 0x0100010000ff01ff,
+    0x0100010000ff0101, 0x010001000000ff00, 0x01000100000000ff, 0x0100010000000000,
+    0x0100010000000001, 0x0100010000000100, 0x010001000001ff01, 0x0100010000010000,
+    0x0100010000010001, 0x0100010000010101, 0x0100010001ffff00, 0x0100010001ff00ff,
+    0x010001000100ffff, 0x010001000100ff01, 0x0100010001000000, 0x0100010001000101,
+    0x010001000101ff00, 0x0100010001010001, 0x01000101ffff0000, 0x01000101ff000000,
+    0x01000101ff010000, 0x0100010100ff00ff, 0x0100010100ff0001, 0x0100010100ff0100,
+    0x010001010000ffff, 0x0100010100000000, 0x01000101000001ff, 0x010001010001ff00,
+    0x0100010101ff0000, 0x010001010100ff00, 0x01000101010000ff, 0x0100010101000000,
+    0x0100010101000001, 0x0101ffffffffffff, 0x0101ffffffffff01, 0x0101ffffffff01ff,
+    0x0101ffffffff0101, 0x0101ffffff000000, 0x0101ffffff01ffff, 0x0101ffffff01ff01,
+    0x0101ffffff0101ff, 0x0101ffffff010101, 0x0101ffff00ff0000, 0x0101ffff0000ff00,
+    0x0101ffff000000ff, 0x0101ffff00000001, 0x0101ffff00000100, 0x0101ffff01ffffff,
+    0x0101ffff01ffff01, 0x0101ffff01ff01ff, 0x0101ffff01ff0101, 0x0101ffff01000000,
+    0x0101ffff0101ffff, 0x0101ffff0101ff01, 0x0101ffff010101ff, 0x0101ffff01010101,
+    0x0101ff00ffff0000, 0x0101ff00ffff0100, 0x0101ff00ff00ff00, 0x0101ff00ff0000ff,
+    0x0101ff00ff000001, 0x0101ff00ff000100, 0x0101ff00ff000101, 0x0101ff0000ff0001,
+    0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000000000, 0x0101ff00000001ff,
+    0x0101ff0000000101, 0x0101ff000001ff00, 0x0101ff00000100ff, 0x0101ff0001ff0000,
+    0x0101ff000100ffff, 0x0101ff000100ff01, 0x0101ff0001000001, 0x0101ff0001000100,
+    0x0101ff01ffffff01, 0x0101ff01ffff01ff, 0x0101ff01ffff0101, 0x0101ff01ff00ffff,
+    0x0101ff01ff000100, 0x0101ff01ff01ff01, 0x0101ff01ff0101ff, 0x0101ff01ff010101,
+    0x0101ff0100ff0000, 0x0101ff010000ff00, 0x0101ff0100000001, 0x0101ff0100000100,
+    0x0101ff0100010000, 0x0101ff0101ffffff, 0x0101ff0101ffff01, 0x0101ff0101ff01ff,
+    0x0101ff0101ff0101, 0x0101ff0101000000, 0x0101ff010101ffff, 0x0101ff010101ff01,
+    0x0101ff01010101ff, 0x0101ff0101010101, 0x010100ffff000100, 0x010100ffff010000,
+    0x010100ff00ffff00, 0x010100ff00ff00ff, 0x010100ff0000ffff, 0x010100ff000000ff,
+    0x010100ff00000000, 0x010100ff000001ff, 0x010100ff00000101, 0x010100ff0001ff00,
+    0x010100ff00010000, 0x010100ff00010001, 0x010100ff000101ff, 0x010100ff00010100,
+    0x010100ff01ff0000, 0x01010000ffff0001, 0x01010000ffff0100, 0x01010000ff00ffff,
+    0x01010000ff00ff01, 0x01010000ff000000, 0x01010000ff0001ff, 0x01010000ff010001,
+    0x01010000ff010100, 0x0101000000ffff01, 0x0101000000ff0000, 0x010100000000ff00,
+    0x01010000000000ff, 0x0101000000000000, 0x0101000000000001, 0x0101000000000100,
+    0x0101000000010000, 0x0101000000010101, 0x0101000001ffff00, 0x0101000001ff00ff,
+    0x0101000001ff0000, 0x0101000001ff0001, 0x0101000001ff0100, 0x010100000100ff01,
+    0x0101000001000000, 0x01010000010001ff, 0x01010001ffff0000, 0x01010001ff00ff00,
+    0x01010001ff000001, 0x01010001ff000101, 0x01010001ff01ff00, 0x01010001ff010000,
+    0x0101000100ff00ff, 0x0101000100ff0001, 0x0101000100ff0101, 0x010100010000ff01,
+    0x0101000100000000, 0x0101000100000001, 0x01010001000001ff, 0x010100010001ffff,
+    0x010100010001ff01, 0x0101000101ff0001, 0x010100010100ffff, 0x0101000101000000,
+    0x0101000101000001, 0x0101000101000100, 0x010100010101ff00, 0x01010001010100ff,
+    0x0101000101010001, 0x010101ffffffffff, 0x010101ffffffff01, 0x010101ffffff01ff,
+    0x010101ffffff0101, 0x010101ffff01ffff, 0x010101ffff01ff01, 0x010101ffff0101ff,
+    0x010101ffff010101, 0x010101ff0000ff00, 0x010101ff000000ff, 0x010101ff00000001,
+    0x010101ff00000100, 0x010101ff01ffffff, 0x010101ff01ffff01, 0x010101ff01ff01ff,
+    0x010101ff01ff0101, 0x010101ff01000000, 0x010101ff0101ffff, 0x010101ff0101ff01,
+    0x010101ff010101ff, 0x010101ff01010101, 0x01010100ffff0000, 0x01010100ff0000ff,
+    0x01010100ff000100, 0x01010100ff01ff00, 0x01010100ff010000, 0x0101010000ffff00,
+    0x010101000000ffff, 0x0101010000000000, 0x0101010000000101, 0x010101000001ff00,
+    0x0101010000010001, 0x0101010000010100, 0x010101000100ffff, 0x0101010001000001,
+    0x01010101ffffffff, 0x01010101ffffff01, 0x01010101ffff01ff, 0x01010101ffff0101,
+    0x01010101ff01ffff, 0x01010101ff01ff01, 0x01010101ff0101ff, 0x01010101ff010101,
+    0x010101010000ff00, 0x01010101000000ff, 0x0101010100000001, 0x0101010101ffffff,
+    0x0101010101ffff01, 0x0101010101ff01ff, 0x0101010101ff0101, 0x0101010101000000,
+    0x010101010101ffff, 0x010101010101ff01, 0x01010101010101ff, 0x0101010101010101,
+GGML_TABLE_END()
+#else
+GGML_TABLE_BEGIN(uint32_t, iq1s_grid_gpu, NGRID_IQ1S)
+    0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
+    0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
+    0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
+    0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
+    0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
+    0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
+    0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
+    0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
+    0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
+    0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
+    0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
+    0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
+    0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
+    0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
+    0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
+    0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
+    0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
+    0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
+    0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
+    0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
+    0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
+    0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
+    0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
+    0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
+    0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
+    0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
+    0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
+    0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
+    0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
+    0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
+    0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
+    0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
+    0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
+    0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
+    0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
+    0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
+    0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
+    0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
+    0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
+    0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
+    0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
+    0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
+    0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
+    0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
+    0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
+    0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
+    0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
+    0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
+    0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
+    0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
+    0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
+    0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
+    0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
+    0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
+    0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
+    0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
+    0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
+    0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
+    0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
+    0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
+    0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
+    0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
+    0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
+    0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
+    0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
+    0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
+    0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
+    0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
+    0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
+    0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
+    0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
+    0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
+    0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
+    0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
+    0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
+    0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
+    0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
+    0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
+    0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
+    0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
+    0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
+    0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
+    0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
+    0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
+    0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
+    0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
+    0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
+    0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
+    0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
+    0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
+    0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
+    0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
+    0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
+    0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
+    0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
+    0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
+    0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
+    0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
+    0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
+    0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
+    0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
+    0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
+    0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
+    0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
+    0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
+    0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
+    0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
+    0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
+    0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
+    0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
+    0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
+    0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
+    0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
+    0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
+    0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
+    0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
+    0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
+    0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
+    0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
+    0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
+    0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
+    0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
+    0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
+    0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
+    0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
+    0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
+    0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
+    0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
+    0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
+    0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
+    0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
+    0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
+    0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
+    0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
+    0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
+    0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
+    0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
+    0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
+    0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
+    0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
+    0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
+    0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
+    0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
+    0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
+    0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
+    0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
+    0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
+    0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
+    0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
+    0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
+    0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
+    0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
+    0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
+    0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
+    0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
+    0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
+    0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
+    0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
+    0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
+    0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
+    0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
+    0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
+    0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
+    0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
+    0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
+    0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
+    0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
+    0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
+    0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
+    0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
+    0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
+    0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
+    0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
+    0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
+    0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
+    0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
+    0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
+    0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
+    0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
+    0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
+    0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
+    0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
+    0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
+    0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
+    0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
+    0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
+    0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
+    0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
+    0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
+    0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
+    0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
+    0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
+    0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
+    0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
+    0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
+    0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
+    0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
+    0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
+    0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
+    0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
+    0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
+    0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
+    0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
+    0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
+    0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
+    0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
+    0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
+    0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
+    0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
+    0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
+    0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
+    0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
+    0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
+    0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
+    0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
+    0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
+    0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
+    0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
+    0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
+    0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
+    0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
+    0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
+    0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
+    0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
+    0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
+    0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
+    0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
+    0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
+    0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
+    0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
+    0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
+    0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
+    0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
+    0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
+    0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
+    0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
+    0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
+    0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
+    0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
+    0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
+    0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
+    0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
+    0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
+    0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
+    0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
+    0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
+    0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
+    0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
+    0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
+    0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
+    0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
+    0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
+    0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
+    0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
+    0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
+    0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
+GGML_TABLE_END()
+#endif
+
+#endif // GGML_COMMON_IMPL
+#endif // GGML_COMMON_IMPL
+#else
+// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
+#include "../ggml-common.h"
+#endif
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef GGML_METAL_IMPL
+#define GGML_METAL_IMPL
+
+// kernel argument structs
+//
+// - element counters (e.g. ne00) typically use int32_t to reduce register usage
+//   however, be careful of int overflows when using these in the kernel implementation
+//
+// - strides (e.g. nb00) use uint64_t
+
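+// Illustrative host-side sketch (assumed, abridged): on the CPU side these structs are
+// populated field-by-field and handed to the compute encoder as a small constant buffer,
+// roughly:
+//
+//     ggml_metal_kargs_concat args = { /*.ne00 =*/ ne00, /* ... remaining fields ... */ };
+//     [encoder setBytes:&args length:sizeof(args) atIndex:0];
+//
+// The buffer index varies per kernel; the layout here must stay in sync with the kernel's
+// declared constant argument.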
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  dim;
+} ggml_metal_kargs_concat;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+} ggml_metal_kargs_bin;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_repeat;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_cpy;
+
+typedef struct {
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+    bool     inplace;
+} ggml_metal_kargs_set;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  n_past;
+    int32_t  n_dims;
+    int32_t  n_ctx_orig;
+    float    freq_base;
+    float    freq_scale;
+    float    ext_factor;
+    float    attn_factor;
+    float    beta_fast;
+    float    beta_slow;
+} ggml_metal_kargs_rope;
+
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb_12_1;
+    uint64_t nb_12_2;
+    uint64_t nb_12_3;
+    uint64_t nb31;
+    int32_t  ne1;
+    int32_t  ne2;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint16_t n_head_log2;
+    float    logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mv;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+    int16_t  nsg;
+    int16_t  nxpsg;
+    int16_t  r1ptg;
+} ggml_metal_kargs_mul_mv_ext;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+} ggml_metal_kargs_mul_mm_id;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+    uint64_t nb1;
+} ggml_metal_kargs_mul_mv_id;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_norm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_rms_norm;
+
+#endif // GGML_METAL_IMPL
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+//
+// cmd:
+//   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal/ggml-metal.metal
+//   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
+//
+#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16)
+#undef GGML_METAL_USE_BF16
+#endif
+
+#if defined(GGML_METAL_USE_BF16)
+typedef matrix<bfloat, 4, 4> bfloat4x4;
+#endif
+
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
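+// Each 4-bit IQ4_NL index selects one of these 16 non-uniform levels; dequantization is,
+// roughly, y = d * kvalues_iq4nl_f[q] with d the per-block scale. The spacing is tighter
+// near zero and wider at the extremes, giving more resolution where most weights fall.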
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src + il));
+}
+
+#if defined(GGML_METAL_USE_BF16)
+template <typename type4x4>
+void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+#endif
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
+    }
+
+    reg = (type4x4) reg_f;
+}
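+// Worked example for the scaling above (illustrative numbers): a Q4_0 block carries one
+// fp16 scale d and 32 packed 4-bit values, reconstructed as d*q + md = d*(q - 8).
+// With d = 0.5 and nibble q = 3 this yields 0.5*3 - 4.0 = -2.5. The d2 = d1/256 factor
+// folds the ">> 8" for the high byte of each uint16 lane into the multiply, so no explicit
+// shift is needed for the second value of each pair.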
+
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + md;
+        reg[2*ii + 1] = d * x1 + md;
+    }
+}
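+
+// The q5_0 dequantizers above rebuild a 5-bit quant from the 4 low bits stored in qs and the
+// 5th bit packed into qh, then decode w = d * (q5 - 16); md = -16 * d supplies the offset.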
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + m;
+        reg[2*ii + 1] = d * x1 + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 16; i++) {
+        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    for (int i = 0; i < 4; i++) {
+        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+    }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const float d = xb->d;
+    const float min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    float dl, ml;
+    uint8_t sc = xb->scales[il];
+
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
+
+    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    const float ml = 4.f * dl;
+
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
+
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+    }
+}
+
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
+    device const uchar * q = xb->qs;
+
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask = il < 2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
+
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d = il < 2 ? xb->d : xb->d / 16.f;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    float sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
+
+    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const float ml = d_all * sc * 32.f;
+    const float dl = d_all * sc * coef;
+    for (int i = 0; i < 16; ++i) {
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
+    }
+}
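+
+// For q6_K the 6-bit quant is assembled from 4 low bits (ql) and 2 high bits (qh); the decoded
+// value works out to d * sc * (q6 - 32), with the sub-block scale sc taken from xb->scales and
+// the "- 32" offset folded into ml = d_all * sc * 32.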
+
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * q3 = xb->qs + 8*ib32;
+    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 8*ib32;
+    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * signs = qs + QK_K/8;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    const float d = xb->d;
+    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint16_t * qh = xb->qh;
+    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+    const uint16_t h = qh[ib32] >> 6*il;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = scale.f16;
+
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+    reg[0] = d * kvalues_iq4nl_f[q8[0]];
+    reg[1] = d * kvalues_iq4nl_f[q8[1]];
+    reg[2] = d * kvalues_iq4nl_f[q8[2]];
+    reg[3] = d * kvalues_iq4nl_f[q8[3]];
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)xb->d * (ls - 32);
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+enum ggml_sort_order {
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
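+
+// kernel_add and the sub/mul/div variants below index purely through the byte strides nb*, so
+// non-contiguous tensors work, and they broadcast src1 by wrapping its indices:
+// i13 = i03 % ne13, i12 = i02 % ne12, i11 = i01 % ne11 and i10 = i0 % ne10. For example, adding
+// a bias row of shape [ne00] to an [ne00, ne01, ne02, ne03] tensor gives ne11 = ne12 = ne13 = 1,
+// so every output row reads the same src1 row while i10 walks it element by element.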
+
+kernel void kernel_sub(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+kernel void kernel_mul(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+kernel void kernel_div(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+template <typename T>
+kernel void kernel_repeat(
+        constant ggml_metal_kargs_repeat & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    const int i03 = i3%args.ne03;
+    const int i02 = i2%args.ne02;
+    const int i01 = i1%args.ne01;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i00 = i0%args.ne00;
+        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
+
+kernel void kernel_sub_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
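+
+// The *_row variants above are the fast path for the common case where src1 is a single
+// contiguous row: they assume ne10 == ne00 and ne00 % 4 == 0 (float4 loads) and broadcast the
+// row with a simple "tpig % nb" wrap, avoiding the per-element stride arithmetic of kernel_add.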
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     float  & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+        device const float * src0,
+        device       float * dst,
+        constant     float & min,
+        constant     float & max,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+    // This was observed with the Falcon 7B and 40B models.
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_elu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
+}
+
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
+kernel void kernel_sin(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = cos(src0[tpig]);
+}
+
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
+template <typename T>
+kernel void kernel_soft_max(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+        lsum += exp_psrc0;
+        pdst[i00] = exp_psrc0;
+    }
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        pdst[i00] *= inv_sum;
+    }
+}
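+
+// kernel_soft_max computes a numerically stable softmax per row: each thread reduces a partial
+// max with simd_max, and when the threadgroup spans more than one simdgroup (ntg > N_SIMDWIDTH)
+// the per-simdgroup results are combined through the threadgroup buffer buf. The optional mask
+// src1 is added with the per-head ALiBi slope (slope = base^exp, chosen from m0/m1 and
+// n_head_log2) before the exp/sum/normalize passes.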
+
+template <typename T>
+kernel void kernel_soft_max_4(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+    float slope = 1.0f;
+
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    }
+
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        pdst4[i00] *= inv_sum;
+    }
+}
+
+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
+kernel void kernel_diag_mask_inf(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i02 = tpig[2];
+    const int64_t i01 = tpig[1];
+    const int64_t i00 = tpig[0];
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
+    }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = ne10;
+  //const int64_t ncs = ne00;
+  //const int64_t nr  = ne01;
+  //const int64_t n_t = ne1;
+  //const int64_t n_s = ne2;
+
+    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// TODO: optimize
+kernel void kernel_ssm_scan_f32(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device      float * dst,
+        constant  int64_t & d_state,
+        constant  int64_t & d_inner,
+        constant  int64_t & n_seq_tokens,
+        constant  int64_t & n_seqs,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant uint64_t & nb20,
+        constant uint64_t & nb21,
+        constant uint64_t & nb22,
+        constant uint64_t & nb30,
+        constant uint64_t & nb31,
+        constant uint64_t & nb40,
+        constant uint64_t & nb41,
+        constant uint64_t & nb42,
+        constant uint64_t & nb50,
+        constant uint64_t & nb51,
+        constant uint64_t & nb52,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i3 = tgpig.y;
+
+    const int64_t nc  = d_state;
+  //const int64_t nr  = d_inner;
+    const int64_t n_t = n_seq_tokens;
+  //const int64_t n_s = n_seqs;
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
+
+        if (i2 > 0) {
+            s0 = s;
+        }
+
+        // i1 == 0
+        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        float x_dt = x[0] * dt_soft_plus;
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            int64_t i = i0;
+            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+    }
+}
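+
+// The scan above implements a selective state-space recurrence (as used by Mamba-style models),
+// following the code directly: dt' = softplus(dt) (passed through unchanged above 20 to avoid
+// overflow), then per channel state = s0 * exp(dt' * A) + B * (x * dt') and y = sum(state * C),
+// carrying the updated state forward across the n_seq_tokens loop.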
+
+kernel void kernel_argmax(
+        device   const void * x,
+        device      int32_t * dst,
+        constant    int64_t & ncols,
+        constant   uint64_t & nb01,
+        threadgroup   float * shared_maxval [[threadgroup(0)]],
+        threadgroup int32_t * shared_argmax [[threadgroup(1)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
+
+    float   lmax = -INFINITY;
+    int32_t larg = -1;
+
+    for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
+        if (x_row[i00] > lmax) {
+            lmax = x_row[i00];
+            larg = i00;
+        }
+    }
+
+    // find the argmax value in the block
+    float max_val = simd_max(lmax);
+    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            shared_maxval[tiisg] = -INFINITY;
+            shared_argmax[tiisg] = -1;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            shared_maxval[sgitg] = max_val;
+            shared_argmax[sgitg] = arg_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = shared_maxval[tiisg];
+        arg_val = shared_argmax[tiisg];
+
+        float max_val_reduced   = simd_max(max_val);
+        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
+
+        dst[tgpig] = arg_val_reduced;
+
+        return;
+    }
+
+    dst[tgpig] = arg_val;
+}
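+
+// kernel_argmax reduces (value, index) pairs without a dedicated argmax intrinsic: after
+// max_val = simd_max(lmax), each lane contributes its index only if its value equals the maximum
+// (select(-1, larg, lmax == max_val)), and simd_max over those indices picks a winning index.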
+
+kernel void kernel_norm(
+        constant ggml_metal_kargs_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float4 sumf4(0.0f);
+
+    float sumf = 0.0f;
+
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf4 += x[i00];
+    }
+    sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean = sumf/args.ne00;
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+
+    sumf = 0.0f;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+        sumf += dot(y[i00], y[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float variance = sumf/args.ne00;
+
+    const float scale = 1.0f/sqrt(variance + args.eps);
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
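+
+// kernel_norm is a standard layer norm over each row: y = (x - mean) / sqrt(variance + eps),
+// computed in two simdgroup/threadgroup reduction passes (mean, then variance); no affine
+// scale/shift is applied in this kernel.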
+
+kernel void kernel_rms_norm(
+        constant ggml_metal_kargs_rms_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean  = sumf/args.ne00;
+    const float scale = 1.0f/sqrt(mean + args.eps);
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
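+
+// kernel_rms_norm normalizes each row by its root mean square: y = x / sqrt(mean(x^2) + eps),
+// using a single reduction pass since no mean subtraction is needed.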
+
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
+// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
+
+    for (int i = 0; i < 8; i += 2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+
+    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
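+
+// The pre-scaling of yl mirrors the bit positions of the four nibbles in each uint16 lane:
+// yl[i+0] (x1), yl[i+1] (x1/256), yl[i+8] (x1/16) and yl[i+9] (x1/4096) pair with the masks
+// 0x000F, 0x0F00, 0x00F0 and 0xF000 (bits 0, 8, 4 and 12). Scaling the activations instead of
+// shifting the quants keeps the inner loop shift-free; the q4_0 "- 8" offset is folded in via
+// the sumy * -8.f term.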
+
+// function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
+}
+
+// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+
+    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
+
+// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied by the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
+}
+
+// putting these constants inside the kernel causes a significant performance penalty
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      guard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+void mul_vec_q_n_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const int nb = args.ne00/QK4_0;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+  //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q_type * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+    }
+
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const short ix = (tiisg/2);
+    const short il = (tiisg%2)*8;
+
+    device const float * yb = y + ix*QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
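+    // ix = tiisg/2 selects the starting block and il = (tiisg%2)*8 selects which half of it;
+    // the loop then strides by nw/2 blocks so the whole simdgroup covers the row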
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        float sumy[2] = { 0.f, 0.f };
+
+#pragma unroll
+        for (int i = 0; i < 8; i += 2) {
+            sumy[0]  += yb[i +  0] + yb[i +  1];
+            yl[i + 0] = yb[i +  0];
+            yl[i + 1] = yb[i +  1]/256.f;
+
+            sumy[1]  += yb[i + 16] + yb[i + 17];
+            yl[i + 8] = yb[i + 16]/16.f;
+            yl[i + 9] = yb[i + 17]/4096.f;
+        }
+
+#pragma unroll
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
+        }
+
+        yb += QK4_0 * 16;
+    }
+
+    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && first_row + row < args.ne01) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+#define NB_Q8_0 8
+
+template<typename args_t>
+void kernel_mul_mv_q8_0_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = args.ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0*nsg + sgitg)*nr;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+  //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q8_0 * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
+    }
+
+    float yl[NB_Q8_0];
+    float sumf[nr] = { 0.f };
+
+    const short ix = tiisg/4;
+    const short il = tiisg%4;
+
+    device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
+
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
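+    // ix = tiisg/4 selects the block and il = tiisg%4 selects which group of 8 quants inside it,
+    // so 4 threads cooperate on one q8_0 block and the simdgroup advances nw/4 blocks per iteration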
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (short i = 0; i < NB_Q8_0; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
+            float sumq = 0.f;
+            for (short iq = 0; iq < NB_Q8_0; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*ax[row][ib].d;
+        }
+
+        yb += nw*NB_Q8_0;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && first_row + row < args.ne01) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
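+// chpt  - chunks processed per thread
+// nxpsg - threads per src0 row within a simdgroup (runtime args.nxpsg, made compile-time by the dispatchers below)
+// r1ptg - rows of src1 handled per threadgroup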
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+void kernel_mul_mv_ext_q4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 4; // chunks per thread
+
+  //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4 * y4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb; // current chunk index
+
+    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+        float4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq  += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    // reduce only the threads in each row
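+    // shuffle-down tree reduction across the nxpsg threads that share a row;
+    // offsets equal to or larger than the row width are skipped so other rows are not mixed in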
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// mat-vec kernel processing in chunks of float4x4
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+void kernel_mul_mv_ext_q4x4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 1;
+
+  //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4x4 * y4x4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb;
+
+    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+        float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4x4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq  += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] +=
+                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4x4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
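+// the host sets args.nxpsg at dispatch time; the switch below maps it onto a template
+// instantiation so the per-row thread count is a compile-time constant inside the impl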
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4_f32_impl<4,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4_f32_impl<8,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4x4_f32_impl<4,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4x4_f32_impl<8,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)    mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4,        4,  dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0,   32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1,   32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0,   32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1,   32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
+#define N_MV_T_T 4
+
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
+void kernel_mul_mv_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T0 * x = (device const T0 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    if (args.ne00 < 128) {
+        for (int row = 0; row < N_MV_T_T; ++row) {
+            int r1 = rb + row;
+            if (r1 >= args.ne11) {
+                break;
+            }
+
+            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+            device const T1 * y = (device const T1 *) (src1 + offset1);
+
+            float sumf = 0;
+            for (int i = tiisg; i < args.ne00; i += 32) {
+                sumf += (T0) x[i] * (T1) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        device const T04 * x4 = (device const T04 *) x;
+        for (int row = 0; row < N_MV_T_T; ++row) {
+            int r1 = rb + row;
+            if (r1 >= args.ne11) {
+                break;
+            }
+
+            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+            device const T1  * y  = (device const T1  *) (src1 + offset1);
+            device const T14 * y4 = (device const T14 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < args.ne00/4; i += 32) {
+                sumf += dot((float4) x4[i], (float4) y4[i]);
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
+        src0,
+        src1,
+        dst,
+        tgpig,
+        tiisg);
+}
+
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv;
+#endif
+
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const T     * x = (device const T     *) (src0 + offset0);
+    device const float * y = (device const float *) (src1 + offset1);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    float sumf = 0;
+    if (args.ne00 < 128) {
+        for (int i = tiisg; i < args.ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst_f32[r0] = all_sum;
+        }
+    } else {
+        device const T4     * x4 = (device const T4     *) x;
+        device const float4 * y4 = (device const float4 *) y;
+
+        for (int i = tiisg; i < args.ne00/4; i += 32) {
+            sumf += dot((float4) x4[i], y4[i]);
+        }
+
+        float all_sum = simd_sum(sumf);
+
+        if (tiisg == 0) {
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+            dst_f32[r0] = all_sum;
+        }
+    }
+}
+
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row;
+#endif
+
+// Assumes row size (ne00) is a multiple of 4
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = args.ne11;
+    const int r0 = tgpig.x;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T4 * x4 = (device const T4 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+        device const float4 * y4 = (device const float4 *) (src1 + offset1);
+
+        float sumf = 0;
+        for (int i = tiisg; i < args.ne00/4; i += 32) {
+            sumf += dot((float4) x4[i], y4[i]);
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+        }
+    }
+}
+
+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4;
+#endif
+
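+// returns 1 for rotation pairs (i0/2) below `low`, 0 above `high`, linearly interpolating
+// in between; used to blend extrapolated and interpolated RoPE angles in rope_yarn below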
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
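+// i.e. the (fractional) dimension index whose rotary frequency completes n_rot full turns over
+// the original context length n_ctx_orig; rope_yarn_corr_dims clamps it to [0, n_dims - 1]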
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
+}
+
+template<typename T>
+kernel void kernel_rope_norm(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_neox(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims/2];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox;
+
+typedef void (im2col_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+//    const int64_t IC = tgpg[0];
+    const int64_t OH = tgpg[1];
+    const int64_t OW = tgpg[2];
+
+//    const int64_t N  = ntg[0];
+    const int64_t KH = ntg[1];
+    const int64_t KW = ntg[2];
+
+    const int64_t in  = tpitg[0];
+    const int64_t ikh = tpitg[1];
+    const int64_t ikw = tpitg[2];
+
+    const int64_t iic = tgpig[0];
+    const int64_t ioh = tgpig[1];
+    const int64_t iow = tgpig[2];
+
+    const int64_t iiw = iow*s0 + ikw*d0 - p0;
+    const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+        pdst[offset_dst] = x[offset_src];
+    }
+}
+
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
+
+typedef void (im2col_ext_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col_ext(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
+    const int64_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+
+    const int64_t d = tgpig[0] / CHW;
+    const int64_t chw = tgpig[0] % CHW;
+    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
+    const int64_t HW = tgpig[0] % KHW;
+
+    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    if (tpitg_0 >= N) {
+        return;
+    }
+
+    const int64_t tpitg_1 = HW / KW;
+    const int64_t tpitg_2 = HW % KW;
+
+    const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+    const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+
+    const int64_t offset_dst =
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext;
+
+typedef void (conv_transpose_1d_t)(
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_1d(
+        device const     T * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    float v = 0.0f;
+
+    for (int64_t c = 0; c < IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+        const int32_t input_offset = c * IL;
+
+        for (int64_t i = 0; i < IL; i++) {
+            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+            }
+        }
+    }
+
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+    dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+    device const float * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+    device const half  * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
+kernel void kernel_upscale_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant     float & sf0,
+    constant     float & sf1,
+    constant     float & sf2,
+    constant     float & sf3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3/sf3;
+    const int64_t i02 = i2/sf2;
+    const int64_t i01 = i1/sf1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int64_t i00 = i0/sf0;
+
+        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+
+        dst_ptr[0] = src0_ptr[0];
+    }
+}
+
+kernel void kernel_pad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+kernel void kernel_pad_reflect_1d_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant   int64_t & ne0,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & p0,
+    constant   int32_t & p1,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3  tgpg[[threadgroups_per_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < p0) {
+                dst_ptr[i0] = src0_ptr[p0 - i0];
+            } else if (i0 < ne0 - p1) {
+                dst_ptr[i0] = src0_ptr[i0 - p0];
+            } else {
+                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+            }
+        }
+    }
+}
+
+kernel void kernel_unpad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            }
+        }
+
+        return;
+    }
+}
+
+kernel void kernel_arange_f32(
+    device        char * dst,
+    constant   int64_t & ne0,
+    constant   float   & start,
+    constant   float   & step,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    device float * dst_ptr = (device float *) dst;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = start + step * i0;
+    }
+}
+
+kernel void kernel_timestep_embedding_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant  uint64_t & nb1,
+    constant  int      & dim,
+    constant  int      & max_period,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    int i = tgpig.x;
+    device float * embed_data = (device float *)(dst +  i*nb1);
+
+    int half_ = dim / 2;
+    for (int j = tpitg.x; j < half_; j += ntg.x) {
+        float timestep = ((device float *)src0)[i];
+        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float arg = timestep * freq;
+        embed_data[j        ] = cos(arg);
+        embed_data[j + half_] = sin(arg);
+    }
+
+    if (dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[dim] = 0.f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
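+// each threadgroup sorts one row of indices in shared memory; ncols_pad is ncols rounded up to a
+// power of two, and out-of-range indices are forced to the end of each sorted run so they are
+// simply not copied back into dst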
+typedef void (argsort_t)(
+        device const float  * x,
+        device     int32_t  * dst,
+        constant   int64_t  & ncols,
+        constant   int64_t  & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols_pad) return;
+
+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;
+
+    // initialize indices
+    dst_row[col] = col;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32;
+
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
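+// (FlashAttention-2) the KV cache is streamed in blocks of C columns per simdgroup while a
+// running row maximum M and row sum S maintain a numerically stable online softmax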
+template<
+    typename q_t,     // query types in shared memory
+    typename q4_t,
+    typename q8x8_t,
+    typename k_t,     // key types in shared memory
+    typename k4x4_t,
+    typename k8x8_t,
+    typename v_t,     // value types in shared memory
+    typename v4x4_t,
+    typename v8x8_t,
+    typename qk_t,    // Q*K types
+    typename qk8x8_t,
+    typename s_t,     // soft-max types
+    typename s8x8_t,
+    typename o_t,     // attention accumulation types
+    typename o4_t,
+    typename o8x8_t,
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,         // head size
+    short Q  = 8,    // queries per threadgroup
+    short KV = 8,    // keys/values processed per simdgroup
+    short C  = 32>   // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3   ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0]*Q;
+
+    const short D4  = D/4;
+    const short D8  = D/8;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
+
+    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
+    const short T  = D + 2*TS; // shared memory size per query in (half)
+
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*D); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*D); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*D); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+
+    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    o8x8_t lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < args.ne01) {
+                sq4[j*D4 + i] = (q4_t) q4[i];
+            } else {
+                sq4[j*D4 + i] = (q4_t) 0.0f;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TS + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        half S[Q] = { [0 ... Q-1] = 0.0f };
+        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+
+        // thread indices inside the simdgroup
+        // TODO: see if we can utilize quad-group functions for better performance
+        //       https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
+        const short tx = tiisg%4;
+        const short ty = tiisg/4;
+
+        // broadcast kv
+        //const short rk2 = args.ne02/args.ne12;
+        //const short rk3 = args.ne03/args.ne13;
+
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+        // load the queries from shared memory into local memory
+        q8x8_t mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, D);
+        }
+
+        const bool has_mask = mask != q;
+
+        half slope = 1.0f;
+
+        // ALiBi
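+        // the per-head slope is base^exph: base is m0 and exph = h + 1 for the first n_head_log2 heads,
+        // and m1 with an odd exponent 2*(h - n_head_log2) + 1 for the remaining heads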
+        if (args.max_bias > 0.0f) {
+            const short h = iq2;
+
+            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+            slope = pow(base, exph);
+        }
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= args.ne11) {
+                break;
+            }
+
+            if (has_mask) {
+                // used to detect blocks full of -INF
+                half smax = -INFINITY;
+
+                // load the mask in shared memory
+                #pragma unroll(Q)
+                for (short j = 0; j < Q; ++j) {
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+
+                    const half m = pm[ic + tiisg];
+
+                    ss[j*TS + C + tiisg] = m;
+                    smax = max(smax, m);
+                }
+
+                smax = simd_max(smax);
+
+                if (smax == -INFINITY) {
+                    continue;
+                }
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f);
+
+                    // this is compile-time check, so it does not have runtime overhead
+                    if (is_same<kd4x4_t, k4x4_t>::value) {
+                        // we can read directly from global memory
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                        #pragma unroll(D8)
+                        for (short i = 0; i < D8; ++i) {
+                            k8x8_t mk;
+                            simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+
+                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                            if (D16%4 == 0) {
+                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+                                {
+                                    k4x4_t tmp;
+                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                    sk4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                #pragma unroll(4)
+                                for (short k = 0; k < 4; ++k) {
+                                    k8x8_t mk;
+
+                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    k4x4_t tmp;
+                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                    sk4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    k8x8_t mk;
+
+                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            }
+                        }
+                    }
+
+                    // cast qk_t -> s_t
+                    //s8x8_t mqks(1.0f);
+                    //simdgroup_multiply(mqks, mqk, mqks);
+                    //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
+                }
+            }
+
+            // online softmax
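+            // running softmax statistics per query row: with a new tile of scores s,
+            //   M_new = max(M_old, max(s)),  S_new = S_old*exp(M_old - M_new) + sum(exp(s - M_new))
+            // the factor exp(M_old - M_new) is stored on the diagonal of ss and later rescales the O accumulator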
+            {
+                for (ushort j = 0; j < Q; ++j) {
+                    const half m = M[j];
+
+                    // scale and apply the logit softcap / mask
+                    half s = ss[j*TS + tiisg]*args.scale;
+
+                    if (args.logit_softcap != 0.0f) {
+                        s = args.logit_softcap*precise::tanh(s);
+                    }
+
+                    // mqk = mqk + mask*slope
+                    s += slope*ss[j*TS + C + tiisg];
+
+                    M[j] = simd_max(max(M[j], s));
+
+                    const half ms = exp(m - M[j]);
+                    const half vs = exp(s - M[j]);
+
+                    S[j] = S[j]*ms + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TS + tiisg] = vs;
+
+                    // create a QxQ diagonal matrix for rescaling the output
+                    if (tiisg == j) {
+                        ss[j*TS + 2*C + j] = ms;
+                    }
+                }
+            }
+
+            // O = diag(ms)*O
+            {
+                s8x8_t mm;
+                simdgroup_load(mm, ss + 2*C, TS, 0, false);
+
+                #pragma unroll(D8)
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    s8x8_t ms;
+                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+
+                    if (is_same<vd4x4_t, v4x4_t>::value) {
+                        // we can read directly from global memory
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                        #pragma unroll(D8)
+                        for (short i = 0; i < D8; ++i) {
+                            v8x8_t mv;
+                            simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+
+                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                            if (D16%4 == 0) {
+                                // no need for bound checks
+                                {
+                                    v4x4_t tmp;
+                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                                    sv4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                #pragma unroll(4)
+                                for (short k = 0; k < 4; ++k) {
+                                    v8x8_t mv;
+
+                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    v4x4_t tmp;
+                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                                    sv4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    v8x8_t mv;
+
+                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TS + 0] = S[j];
+                ss[j*TS + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
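+    // each simdgroup holds partial (S, M, O) statistics; merge them into simdgroup 0 with the usual rescaling:
+    //   M = max(M0, M1),  S = S0*exp(M0 - M) + S1*exp(M1 - M),  O = exp(M0 - M)*O_0 + exp(M1 - M)*O_1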
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        half S = { 0.0f };
+        half M = { -__FLT16_MAX__/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], so + i*8, D, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const half S0 = ss[j*TS +         0];
+                const half S1 = ss[j*TS + sg*SH + 0];
+
+                const half M0 = ss[j*TS +         1];
+                const half M1 = ss[j*TS + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const half ms0 = exp(M0 - M);
+                const half ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TS + 0] = S;
+                    ss[j*TS + 1] = M;
+
+                    ss[j*TS + 2*C + j        ] = ms0;
+                    ss[j*TS + 2*C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                s8x8_t ms0;
+                s8x8_t ms1;
+
+                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+
+                #pragma unroll(D8)
+                for (short i = 0; i < D8; ++i) {
+                    o8x8_t t;
+
+                    simdgroup_load    (t, so + i*8, D, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], so + i*8, D, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
+            const float S = ss[j*TS + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+            }
+        }
+    }
+}
+
+// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
+//       template to be able to explore different combinations
+//
+#define FA_TYPES \
+    half,  half4,   simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    float,          simdgroup_float8x8, \
+    float,          simdgroup_float8x8, \
+    half,  half4,   simdgroup_half8x8
+
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+#endif
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+#undef FA_TYPES
+
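+// variant of kernel_flash_attn_ext specialized for a single query per threadgroup (Q = 1):
+// the output is accumulated in thread-local 4x4 chunks instead of simdgroup matrices, and the partial
+// results are combined with simd_shuffle_down reductions followed by a parallel reduce across simdgroups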
+template<
+    typename q4_t,    // query types in shared memory
+    typename q4x4_t,
+    typename k4x4_t,  // key types in shared memory
+    typename v4x4_t,  // value types in shared memory
+    typename qk_t,    // Q*K types
+    typename s_t,     // soft-max types
+    typename s4_t,
+    typename s4x4_t,
+    typename o4x4_t,  // attention accumulation types
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,         // head size
+    short Q  = 1,    // queries per threadgroup
+    short C  = 32>   // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3   ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0];
+
+    const short D4  = D/4;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
+    const short SH  = 2*C;  // shared memory per simdgroup
+
+    const short T = D + nsg*SH; // shared memory size per query in (half)
+
+  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 +                0*D); // holds the query data
+    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 +                0*D); // same as above but in q4_t
+    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 +                0*D); // same as above but in q4x4_t
+    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + sgitg*SH     + Q*D); // scratch buffer for attention
+    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH     + Q*D); // same as above but in s4_t
+    threadgroup half   * sm    = (threadgroup half   *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
+    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D      + Q*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    o4x4_t lo[D16/NL];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < args.ne01) {
+            sq4[i] = (q4_t) q4[i];
+        } else {
+            sq4[i] = (q4_t) 0.0f;
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D16/NL; ++i) {
+        lo[i] = (o4x4_t) 0.0f;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = (s4_t) 0.0f;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        half S = 0.0f;
+        half M = -__FLT16_MAX__/2;
+
+        // thread indices inside the simdgroup
+        const short tx = tiisg%NL;
+        const short ty = tiisg/NL;
+
+        // broadcast kv
+        //const short rk2 = args.ne02/args.ne12;
+        //const short rk3 = args.ne03/args.ne13;
+
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+        // load the queries from shared memory into local memory
+        q4x4_t mq[D16/NL];
+
+        #pragma unroll(D16/NL)
+        for (short ii = 0; ii < D16; ii += NL) {
+            mq[ii/NL] = sq4x4[ii + tx];
+        }
+
+        const bool has_mask = mask != q;
+
+        // pointer to the mask
+        device const half * pm = (device const half *) (mask + iq1*args.nb31);
+
+        half slope = 1.0f;
+
+        // ALiBi
+        if (args.max_bias > 0.0f) {
+            const short h = iq2;
+
+            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+            slope = pow(base, exph);
+        }
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= args.ne11) {
+                break;
+            }
+
+            if (has_mask) {
+                sm[tiisg] = pm[ic + tiisg];
+            }
+
+            // Q*K^T
+            {
+                // each simdgroup processes 1 query and 4 (NW/NL) keys
+                for (short cc = 0; cc < C/4; ++cc) {
+                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
+
+                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                    #pragma unroll(D16/NL)
+                    for (short ii = 0; ii < D16; ii += NL) {
+                        const short i = ii + tx;
+
+                        k4x4_t mk;
+                        deq_k(pk + i/nl_k, i%nl_k, mk);
+
+                        // note: this is less precise than the version below
+                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
+                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
+                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
+                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
+
+                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
+                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
+                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
+                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
+                    }
+
+                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+
+                    // simdgroup reduce
+                    // [ 0 ..  7] -> [ 0]
+                    // [ 8 .. 15] -> [ 8]
+                    // [16 .. 23] -> [16]
+                    // [24 .. 31] -> [24]
+                  //mqk += simd_shuffle_down(mqk, 16);
+                  //mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask*slope
+                    if (tx == 0) {
+                        mqk *= args.scale;
+
+                        if (args.logit_softcap != 0.0f) {
+                            mqk = args.logit_softcap*precise::tanh(mqk);
+                        }
+
+                        mqk += sm[4*cc + ty]*slope;
+
+                        ss[4*cc + ty] = mqk;
+                    }
+                }
+            }
+
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
+            // online softmax
+            {
+                const half m = M;
+                const half s = ss[tiisg];
+
+                M = simd_max(max(M, s));
+
+                const half ms = exp(m - M);
+                const half vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[tiisg] = vs;
+
+                // O = diag(ms)*O
+                #pragma unroll(D16/NL)
+                for (short ii = 0; ii < D16; ii += NL) {
+                    lo[ii/NL] *= ms;
+                }
+            }
+
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                    const s4x4_t ms(ss[4*cc + ty]);
+
+                    #pragma unroll(D16/NL)
+                    for (short ii = 0; ii < D16; ii += NL) {
+                        const short i = ii + tx;
+
+                        v4x4_t mv;
+                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
+
+                        lo[ii/NL] += mv*ms;
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = (s_t) S;
+            ss[1] = (s_t) M;
+        }
+    }
+
+    // simdgroup reduce
+    // [ 0,  8, 16, 24] -> [ 0]
+    // [ 1,  9, 17, 25] -> [ 1]
+    // [ 2, 10, 18, 26] -> [ 2]
+    // [ 3, 11, 19, 27] -> [ 3]
+    // [ 4, 12, 20, 28] -> [ 4]
+    // [ 5, 13, 21, 29] -> [ 5]
+    // [ 6, 14, 22, 30] -> [ 6]
+    // [ 7, 15, 23, 31] -> [ 7]
+    for (short ii = 0; ii < D16; ii += NL) {
+        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+
+        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+
+        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+
+        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // store results to shared memory
+    for (short i = tiisg; i < D16; i += NL) {
+        sr4x4[i] = lo[i/NL];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const half S0 = ss[       0];
+            const half S1 = ss[r*SH + 0];
+
+            const half M0 = ss[       1];
+            const half M1 = ss[r*SH + 1];
+
+            const half M = max(M0, M1);
+
+            const half ms0 = exp(M0 - M);
+            const half ms1 = exp(M1 - M);
+
+            const half S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short i = tiisg; i < D16; i += NW) {
+                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    device float4x4 * dst44 = (device float4x4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short i = tiisg; i < D16; i += NW) {
+            dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+        }
+    }
+}
+
+// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+//       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+//
+#define FA_TYPES \
+           half4,  half4x4, \
+                   half4x4, \
+                   half4x4, \
+    float,                  \
+    half,  half4,  half4x4, \
+                   half4x4
+
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+#undef FA_TYPES
+
+template<typename T>
+kernel void kernel_set(
+    constant ggml_metal_kargs_set & args,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    ushort3 tpitg[[thread_position_in_threadgroup]],
+    ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i13 = tgpig[2];
+    const int i12 = tgpig[1];
+    const int i11 = tgpig[0];
+
+    const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
+
+    const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
+    const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
+    const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
+
+    device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
+
+    for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
+        device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
+        dst_data[i10] = (T) src[0];
+    }
+}
+
+typedef decltype(kernel_set<float>) kernel_set_t;
+
+template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set;
+template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set;
+
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+        dst_data[i00] = (T1) src[0];
+    }
+}
+
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+
+template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy;
+template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy;
+#endif
+template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy;
+template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy;
+template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy;
+#endif
+
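+// quantize a row of f32 values to Q8_0: each block of QK8_0 values stores one f16 scale d = amax/127
+// and QK8_0 signed 8-bit quants q = round(x/d)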
+kernel void kernel_cpy_f32_q8_0(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
+
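+// quantize a row of f32 values to Q4_0: the scale is picked so that the value with the largest magnitude
+// maps to -8, and each value is stored as a 4-bit code in [0, 15] with an implicit offset of 8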
+kernel void kernel_cpy_f32_q4_0(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j]  = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q4_1(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j]  = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
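+// quantize a row of f32 values to Q5_0: 5-bit codes with an implicit offset of 16; the low 4 bits of each
+// code are packed into qs and the 5th bits are gathered into the 32-bit qh bitfield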
+kernel void kernel_cpy_f32_q5_0(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
+
+    device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK5_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_0].d = d;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK5_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_0].qh[j] = qh8[j];
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q5_1(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
+
+    device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float max = src[0];
+        float min = src[0];
+
+        for (int j = 1; j < QK5_1; j++) {
+            const float v = src[j];
+            min = v < min ? v : min;
+            max = v > max ? v : max;
+        }
+
+        const float d = (max - min) / 31;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_1].d = d;
+        dst_data[i00/QK5_1].m = min;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_1].qh[j] = qh8[j];
+        }
+    }
+}
+
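+// binary search over the sorted codebook val[0..n-1], returning the index of the entry closest to x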
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
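+// quantize a row of f32 values to IQ4_NL: each value is mapped to the nearest entry of the non-linear
+// kvalues_iq4nl_f codebook and the block scale is then refined with a weighted least-squares fit (sumqx/sumq2)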
+kernel void kernel_cpy_f32_iq4_nl(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
+
+    device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_NL; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / kvalues_iq4nl_f[0];
+        const float id = d ? 1.0f/d : 0.0f;
+
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            const float x0 = src[0        + j]*id;
+            const float x1 = src[QK4_NL/2 + j]*id;
+
+            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+            const float v0 = kvalues_iq4nl_f[xi0];
+            const float v1 = kvalues_iq4nl_f[xi1];
+            const float w0 = src[0        + j]*src[0        + j];
+            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+            sumq2 += w0*v0*v0 + w1*v1*v1;
+
+        }
+
+        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+    }
+}
+
+kernel void kernel_concat(
+    constant ggml_metal_kargs_concat & args,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    ushort3 tpitg[[thread_position_in_threadgroup]],
+    ushort3   ntg[[threads_per_threadgroup]]) {
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    int o[4] = {0, 0, 0, 0};
+    o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
+
+    device const float * x;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+            x = (device const float *)(src0 + (i3       )*args.nb03 + (i2       )*args.nb02 + (i1       )*args.nb01 + (i0       )*args.nb00);
+        } else {
+            x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
+        }
+
+        device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+        *y = *x;
+    }
+}
+
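+// matrix-vector multiplication for Q2_K-quantized weights: each simdgroup produces N_DST output rows,
+// iterating over the QK_K = 256 element super-blocks and unpacking the 2-bit quants and 4-bit scales/mins on the fly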
+template<typename args_t>
+void kernel_mul_mv_q2_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int iq = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
+
+    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }
+
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += args.nb01/2;
+            sc += args.nb01;
+            dh += args.nb01/2;
+        }
+
+        y4 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q3_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
+
+    float yl[32];
+
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
+
+    const short shift = 2*il;
+
+    const float v1 = il == 0 ? 4.f : 64.f;
+    const float v2 = 4.f * v1;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + il;
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
+        for (int l = 0; l < 8; ++l) {
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
+        }
+
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
+
+        for (int row = 0; row < 2; ++row) {
+            const float d_all = (float)dh[0];
+
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+            }
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
+
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+            }
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
+
+            q  += args.nb01/2;
+            h  += args.nb01/2;
+            a  += args.nb01/2;
+            dh += args.nb01/2;
+        }
+
+        y1 += 4 * QK_K;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst_f32[first_row + row] = sumf1[row];
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q4_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int iq = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
+
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
+
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
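+            // unpack the packed 6-bit block scales and mins: sc8[0,1,4,5] scale the quantized dot products, sc8[2,3,6,7] weight the sub-block sums below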
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += args.nb01/2;
+            sc += args.nb01/2;
+            dh += args.nb01/2;
+        }
+
+        y4 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q5_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
+
+    float sumf[2]={0.f};
+
+    float yl[16], yh[16];
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int iq  = tid/4;
+    const int ir  = tid%4;
+    const int n   = 8;
+
+    const int l0 = n*ir;
+    const int q_offset = 32*iq + l0;
+    const int y_offset = 64*iq + l0;
+
+    const uint8_t hm1 = 1u << (2*iq);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    for (int i = ix; i < nb; i += 4) {
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
+
+        for (int row = 0; row < 2; ++row) {
+            device const uint8_t * q2 = q1 + 64;
+
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
+            for (int l = 0; l < n; ++l) {
+                uint8_t h = qh[l];
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+            }
+            const float dall = dh[0];
+            const float dmin = dh[1];
+            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += args.nb01;
+            qh += args.nb01;
+            dh += args.nb01/2;
+            a  += args.nb01/2;
+        }
+
+        y1 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template <typename args_t>
+void kernel_mul_mv_q6_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int row = 2*r0 + sgitg;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
+
+    float sumf = 0;
+
+    const int tid  = tiisg/2;
+    const int ix   = tiisg%2;
+    const int ip   = tid/8;         // 0 or 1
+    const int il   = tid%8;
+    const int n    = 4;
+    const int l0   = n*il;
+    const int is   = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
+
+    for (int i = ix; i < nb; i += 2) {
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        const float dall = x[i].d;
+
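+        // each 6-bit quant is rebuilt from the low 4 bits in ql plus 2 high bits from qh, then recentred by subtracting 32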
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst_f32[row] = tot;
+    }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+// ======================= "True" 2-bit
+
+template<typename args_t>
+void kernel_mul_mv_iq2_xxs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
+    device const float         * y = (device const float         *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
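+    // stage the iq2_xxs grid values and the shared sign table in threadgroup memory; each thread copies a slice, then all threads synchronize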
+    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xxs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            device const uint8_t * aux8 = (device const uint8_t *)q2;
+            const uint32_t aux32 = q2[2] | (q2[3] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float sum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
+                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * sum;
+
+            dh += args.nb01/2;
+            q2 += args.nb01/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq2_xs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 512);
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const uint8_t  * sc = xr->scales + ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const uint8_t ls1 = sc[0] & 0xf;
+            const uint8_t ls2 = sc[0] >>  4;
+            const float d1 = db * (0.5f + ls1);
+            const float d2 = db * (0.5f + ls2);
+
+            float sum1 = 0, sum2 = 0;
+            for (int l = 0; l < 2; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+                const uint8_t signs = ssigns[(q2[l] >> 9)];
+                for (int j = 0; j < 8; ++j) {
+                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            for (int l = 2; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+                const uint8_t signs = ssigns[(q2[l] >> 9)];
+                for (int j = 0; j < 8; ++j) {
+                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d1 * sum1 + d2 * sum2;
+
+            dh += args.nb01/2;
+            q2 += args.nb01/2;
+            sc += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq3_xxs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
+    device const float         * y = (device const float         *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_xxs * xr = x + ibl;
+        device const uint8_t  * q3 = xr->qs + 8 * ib;
+        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+            const float db = dh[0];
+            const uint32_t aux32 = gas[0] | (gas[1] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
+                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += args.nb01/2;
+            q3  += args.nb01;
+            gas += args.nb01/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.5f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq3_s_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 8 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + (ib/2);
+        device const uint8_t * signs = xr->signs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
+                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh    += args.nb01/2;
+            qs    += args.nb01;
+            qh    += args.nb01;
+            sc    += args.nb01;
+            signs += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_s_f32")]]
+kernel void kernel_mul_mv_iq3_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq2_s_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
+    //{
+    //    int nval = 32;
+    //    int pos  = (32*sgitg + tiisg)*nval;
+    //    for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 4 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + ib;
+        device const uint8_t * signs = qs + QK_K/8;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d1 = db * (0.5f + (sc[0] & 0xf));
+            const float d2 = db * (0.5f + (sc[0] >>  4));
+
+            float2 sum = {0};
+            for (int l = 0; l < 2; ++l) {
+                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
+                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
+                }
+            }
+            sumf[row] += d1 * sum[0] + d2 * sum[1];
+
+            dh    += args.nb01/2;
+            qs    += args.nb01;
+            qh    += args.nb01;
+            sc    += args.nb01;
+            signs += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_s_f32")]]
+kernel void kernel_mul_mv_iq2_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq1_s_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float sumy = 0;
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+            sumy += yl[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_s * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint16_t * qh = xr->qh + ib;
+        device const half     * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
+
+            float sum = 0;
+            for (int j = 0; j < 4; ++j) {
+                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
+
+            dh += args.nb01/2;
+            qs += args.nb01;
+            qh += args.nb01/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+template <typename args_t>
+void kernel_mul_mv_iq1_m_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    iq1m_scale_t scale;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float4 sumy = {0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_m * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint8_t  * qh = xr->qh + 2 * ib;
+        device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+        for (int row = 0; row < N_DST; row++) {
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+            float2 sum = {0.f};
+            for (int j = 0; j < 4; ++j) {
+                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+
+            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+            sc += args.nb01/2;
+            qs += args.nb01;
+            qh += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq4_nl_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK4_NL;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0 or 1
+
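+    // copy the 16-entry iq4_nl codebook into threadgroup memory (the 32 threads write it twice over)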
+    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK4_NL + it * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+
+            device const block_iq4_nl & xb = x[row*nb + ib];
+            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] | (q4[1] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[2] | (q4[3] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 16 * QK4_NL;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq4_xs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    const int ix = tiisg/16;  // 0 or 1
+    const int it = tiisg%16;  // 0...15
+    const int ib = it/2;
+    const int il = it%2;
+
+    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ibl = ix; ibl < nb; ibl += 2) {
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+            device const block_iq4_xs & xb = x[row*nb + ibl];
+            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = (q4[0]     ) & 0x0f0f0f0f;
+            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = (q4[1]     ) & 0x0f0f0f0f;
+            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 2 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows_q(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
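+    // dequantize the selected row 16 values at a time: each iteration writes one float4x4 into dst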
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+        float4x4 temp;
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+template<typename T>
+kernel void kernel_get_rows_f(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
+    }
+}
+
+kernel void kernel_get_rows_i32(
+        device const  void * src0,
+        device const  void * src1,
+        device     int32_t * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+    }
+}
+
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread takes 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread takes 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 threads per row of matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 threads per row of matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(
+        constant ggml_metal_kargs_mul_mm & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup T     * sa = (threadgroup T     *)(shmem);
+    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+    const int r0 = tgpig.y;
+    const int r1 = tgpig.x;
+    const int im = tgpig.z;
+
+    // if this block is of 64x32 shape or smaller
+    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_T8x8     ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const int i12 = im%args.ne12;
+    const int i13 = im/args.ne12;
+
+    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short    offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0
+        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+
+    device const float   * y = (device const float   *)(src1
+        + args.nb13*i13
+        + args.nb12*i12
+        + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
+        + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        T4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        #pragma unroll(16)
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+            +                     (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+            +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup const T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
+
+        #pragma unroll(4)
+        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
+            #pragma unroll(4)
+            for (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+            #pragma unroll(2)
+            for (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            }
+
+            #pragma unroll(8)
+            for (short i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+            }
+
+            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
+        }
+    }
+
+    if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
+        device float * C = (device float *) dst +
+            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *) shmem) \
+                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg == 0) {
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
+                }
+            }
+        }
+    }
+}
+
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
+// TODO: this kernel needs to be reimplemented from scratch for better performance
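+// rowids is filled by the caller (kernel_mul_mm_id below) with one ushort2 per token routed
+// to the current expert: .x is the position within the expert-selection row (used as
+// i11 = id[0] % ne11 for src1 and as the dst row), .y is the selection row itself (used as
+// i12 for src1 and as the dst batch).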
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        int32_t  ne00,
+        int32_t  ne02,
+        uint64_t nb01,
+        uint64_t nb02,
+        int32_t  ne11,
+        int32_t  ne12,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int32_t  ne0,
+        int32_t  ne1,
+        int64_t  ne0ne1,
+        device   const char * src0,
+        device   const char * src1,
+        threadgroup ushort2 * rowids,
+        device         char * dst,
+        threadgroup    char * shmem,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shmem);
+    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+    const int r0 = tgpig.y;
+    const int r1 = tgpig.x;
+
+    if (r1*BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 mc[8];
+    for (int i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8, 8>(0.f);
+    }
+    short il = (tiitg % THREAD_PER_ROW);
+
+    ushort offset1 = il/nl;
+
+    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * id[1]
+        + nb11 * (id[0] % ne11)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        #pragma unroll(BLOCK_SIZE_K/8)
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            #pragma unroll(8)
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *) shmem) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg == 0) {
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+                int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
+
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + joff;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
+                }
+            }
+        }
+    }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+        constant ggml_metal_kargs_mul_mm_id & args,
+        device const char * src0s,
+        device const char * src1,
+        device       char * dst,
+        device const char * ids,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int32_t i02 = tgpig.z;
+
+    tgpig.z = 0;
+
+    device const char * src0 = src0s + i02*args.nb02;
+
+    // row indices
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
+
+    // TODO: parallelize this loop
+    int32_t _ne1 = 0;
+    for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
+        for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
+            int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
+            if (id == i02) {
+                if (tiitg == 0) {
+                    rowids[_ne1] = ushort2(ii0, ii1);
+                }
+                _ne1++;
+            }
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        args.ne00,
+        args.ne02,
+        args.nb01,
+        args.nb02,
+        args.ne11,
+        args.ne12,
+        args.nb10,
+        args.nb11,
+        args.nb12,
+        args.ne0,
+        _ne1,
+        (int64_t)args.ne0*args.ne1,
+        src0,
+        src1,
+        rowids,
+        dst,
+        shmem,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+#define QK_NL 16
+
+//
+// get rows
+//
+
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f;
+#endif
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q;
+template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm;
+#endif
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+#endif
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id;
+
+//
+// matrix-vector multiplication
+//
+
+typedef void (kernel_mul_mv_impl_t)(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg);
+
+typedef void (kernel_mul_mv2_impl_t)(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg);
+
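+// The two mmv_fn overloads below adapt the two matrix-vector implementation signatures
+// (without and with threadgroup memory / simdgroup index) to one uniform signature, so that
+// kernel_mul_mv_id can be instantiated the same way for every source type.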
+template<kernel_mul_mv_impl_t impl_fn>
+void mmv_fn(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiitg,
+        ushort tiisg,
+        ushort sgitg) {
+    impl_fn(args, src0, src1, dst, tgpig, tiisg);
+}
+
+template<kernel_mul_mv2_impl_t impl_fn>
+void mmv_fn(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiitg,
+        ushort tiisg,
+        ushort sgitg) {
+    impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(mmv_fn>) mul_mv_impl_fn_t;
+
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
+        constant ggml_metal_kargs_mul_mv_id & args,
+        device const char * src0s,
+        device const char * src1,
+        device       char * dst,
+        device const char * ids,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int iid1 = tgpig.z/args.nei0;
+    const int idx  = tgpig.z%args.nei0;
+
+    tgpig.z = 0;
+
+    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];
+
+    const int64_t i11 = idx % args.ne11;
+    const int64_t i12 = iid1;
+
+    const int64_t i1 = idx;
+    const int64_t i2 = i12;
+
+    device const char * src0_cur = src0s + i02*args.nb02;
+    device const char * src1_cur = src1  + i11*args.nb11 + i12*args.nb12;
+
+    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
+
+    ggml_metal_kargs_mul_mv args0 = {
+        /*.ne00 =*/ args.ne00,
+        /*.ne01 =*/ args.ne01,
+        /*.ne02 =*/ 1, // args.ne02,
+        /*.nb00 =*/ args.nb00,
+        /*.nb01 =*/ args.nb01,
+        /*.nb02 =*/ args.nb02,
+        /*.nb03 =*/ args.nb02, // args.ne02 == 1
+        /*.ne10 =*/ args.ne10,
+        /*.ne11 =*/ 1, // args.ne11,
+        /*.ne12 =*/ 1, // args.ne12,
+        /*.nb10 =*/ args.nb10,
+        /*.nb11 =*/ args.nb11,
+        /*.nb12 =*/ args.nb12,
+        /*.nb13 =*/ args.nb12, // ne12 == 1
+        /*.ne0  =*/ args.ne0,
+        /*.ne1  =*/ 1, // args.ne1,
+        /*.r2   =*/ 1,
+        /*.r3   =*/ 1,
+    };
+
+    impl_fn(
+        args0,
+        /* src0 */ src0_cur,
+        /* src1 */ src1_cur,
+        /* dst  */ dst_cur,
+        shmem,
+        tgpig,
+        tiitg,
+        tiisg,
+        sgitg);
+}
+
+typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+#endif
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+
+kernel void kernel_pool_2d_max_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+
+    float res = -INFINITY;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            res = MAX(res, i_ptr[i * IW + j]);
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
+
+kernel void kernel_pool_2d_avg_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+    // const float scale = 1. / ((eh - bh) * (ew - bw));
+    const float scale = 1. / (k0 * k1);
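+    // dividing by k0*k1 averages over the full kernel window, counting padded positions as 0;
+    // the commented-out line above would instead average over only the in-bounds elements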
+
+    float res = 0;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            float cur = i_ptr[i * IW + j];
+            res += cur * scale;
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
diff --git a/llama/ggml-metal-embed_darwin_arm64.s b/llama/ggml-metal-embed_darwin_arm64.s
new file mode 100644
index 000000000..a108c8251
--- /dev/null
+++ b/llama/ggml-metal-embed_darwin_arm64.s
@@ -0,0 +1,6 @@
+.section __DATA, __ggml_metallib
+.globl _ggml_metallib_start
+_ggml_metallib_start:
+.incbin "ggml-metal-embed.metal"
+.globl _ggml_metallib_end
+_ggml_metallib_end:
\ No newline at end of file
diff --git a/llama/ggml-metal-impl.h b/llama/ggml-metal-impl.h
new file mode 100644
index 000000000..19103fb59
--- /dev/null
+++ b/llama/ggml-metal-impl.h
@@ -0,0 +1,314 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef GGML_METAL_IMPL
+#define GGML_METAL_IMPL
+
+// kernel argument structs
+//
+// - element counters (e.g. ne00) typically use int32_t to reduce register usage
+//   however, be careful from int overflows when using those in the kernel implementation
+//
+// - strides (e.g. nb00) use uint64_t
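+//
+// - as an illustration: for a contiguous f32 tensor, ggml's convention makes nb00 equal to
+//   sizeof(float) and nb01 equal to nb00*ne00, i.e. all strides are expressed in bytes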
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  dim;
+} ggml_metal_kargs_concat;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+} ggml_metal_kargs_bin;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_repeat;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_cpy;
+
+typedef struct {
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+    bool     inplace;
+} ggml_metal_kargs_set;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  n_past;
+    int32_t  n_dims;
+    int32_t  n_ctx_orig;
+    float    freq_base;
+    float    freq_scale;
+    float    ext_factor;
+    float    attn_factor;
+    float    beta_fast;
+    float    beta_slow;
+} ggml_metal_kargs_rope;
+
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb_12_1;
+    uint64_t nb_12_2;
+    uint64_t nb_12_3;
+    uint64_t nb31;
+    int32_t  ne1;
+    int32_t  ne2;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint16_t n_head_log2;
+    float    logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mv;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+    int16_t  nsg;
+    int16_t  nxpsg;
+    int16_t  r1ptg;
+} ggml_metal_kargs_mul_mv_ext;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+} ggml_metal_kargs_mul_mm_id;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+    uint64_t nb1;
+} ggml_metal_kargs_mul_mv_id;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_norm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_rms_norm;
+
+#endif // GGML_METAL_IMPL
diff --git a/llama/ggml-metal.h b/llama/ggml-metal.h
new file mode 100644
index 000000000..c3e7023e4
--- /dev/null
+++ b/llama/ggml-metal.h
@@ -0,0 +1,92 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+// Note: this description is outdated
+//
+// An interface allowing to compute ggml_cgraph with Metal
+//
+// This is a fully functional interface that extends ggml with GPU support for Apple devices.
+// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
+//
+// How it works?
+//
+// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
+// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
+// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
+//
+// You only need to make sure that all memory buffers that you used during the graph creation
+// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
+// used during the graph evaluation to determine the arguments of the compute kernels.
+//
+// Synchronization between device and host memory (for example for input and output tensors)
+// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
+//
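+// With the current ggml-backend API, a minimal usage sketch (error handling omitted, `ctx`
+// and `gf` assumed to be an existing ggml context and graph) looks roughly like:
+//
+//   ggml_backend_t backend = ggml_backend_metal_init();
+//   ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); // from ggml-alloc.h
+//   ggml_backend_graph_compute(backend, gf);
+//   ggml_backend_free(backend);
+//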
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stddef.h>
+#include <stdbool.h>
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//
+// backend API
+// user-code should use only these functions
+//
+
+GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
+
+GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
+
+GGML_DEPRECATED(
+        GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
+
+GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
+
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+
+// helper to check if the device supports a specific family
+// ideally, the user code should be doing these checks
+// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
+
+// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
+GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml-metal.metal b/llama/ggml-metal.metal
new file mode 100644
index 000000000..1bca09725
--- /dev/null
+++ b/llama/ggml-metal.metal
@@ -0,0 +1,6803 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_DECL_METAL
+#define GGML_COMMON_IMPL_METAL
+#if defined(GGML_METAL_EMBED_LIBRARY)
+__embed_ggml-common.h__
+#else
+// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
+#include "../ggml-common.h"
+#endif
+#include "ggml-metal-impl.h"
+
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
+// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+//
+// cmd:
+//   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal/ggml-metal.metal
+//   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
+//
+#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16)
+#undef GGML_METAL_USE_BF16
+#endif
+
+#if defined(GGML_METAL_USE_BF16)
+typedef matrix<bfloat, 4, 4> bfloat4x4;
+#endif
+
+constexpr constant static float kvalues_iq4nl_f[16] = {
+    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
+};
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src + il));
+}
+
+#if defined(GGML_METAL_USE_BF16)
+template <typename type4x4>
+void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
+    reg = (type4x4)(*src);
+}
+#endif
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + md;
+        reg[2*ii + 1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + m;
+        reg[2*ii + 1] = d * x1 + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    float4x4 reg_f;
+
+    for (int i = 0; i < 16; i++) {
+        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+
+    reg = (type4x4) reg_f;
+}
+
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    for (int i = 0; i < 4; i++) {
+        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+    }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const float d = xb->d;
+    const float min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    float dl, ml;
+    uint8_t sc = xb->scales[il];
+
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
+
+    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                               : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    const float ml = 4.f * dl;
+
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    dl *= coef;
+
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+    }
+}
+
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
+    device const uchar * q = xb->qs;
+
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d   = il < 2 ? xb->d : xb->d / 16.h;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask = il < 2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
+
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const float d = il < 2 ? xb->d : xb->d / 16.f;
+    const float min = xb->dmin;
+    const float dl = d * sc[0];
+    const float ml = min * sc[1];
+
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const half d_all = xb->d;
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    float sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
+
+    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
+    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const float ml = d_all * sc * 32.f;
+    const float dl = d_all * sc * coef;
+    for (int i = 0; i < 16; ++i) {
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+    const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+    thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+    const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+    uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+    signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint16_t * q2 = xb->qs + 4*ib32;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+    uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+    grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+    signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+    for (int i = 0; i < 8; ++i) {
+        reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * q3 = xb->qs + 8*ib32;
+    device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+    const uint32_t aux32 = gas[0] | (gas[1] << 16);
+    const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+    uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+    grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+    grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+    signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+        reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 8*ib32;
+    device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+    constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+    grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const float d = xb->d;
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * signs = qs + QK_K/8;
+    const uint8_t qh = xb->qh[ib32] >> 4*il;
+    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+    for (int i = 0; i < 8; ++i) {
+        reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+        reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    const float d = xb->d;
+    device const uint8_t  * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint16_t * qh = xb->qh;
+    const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+    const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+    const uint16_t h = qh[ib32] >> 6*il;
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+    iq1m_scale_t scale;
+    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+    const float d = scale.f16;
+
+    device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+    device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+    const float dl  = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+    const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+    constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+    constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+        reg[1][i] = dl * (grid1[i] >>  4) + ml1;
+        reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+        reg[3][i] = dl * (grid2[i] >>  4) + ml2;
+    }
+}
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+    reg[0] = d * kvalues_iq4nl_f[q8[0]];
+    reg[1] = d * kvalues_iq4nl_f[q8[1]];
+    reg[2] = d * kvalues_iq4nl_f[q8[2]];
+    reg[3] = d * kvalues_iq4nl_f[q8[3]];
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+    // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+    const int ib32 = il/2;
+    il = il%2;
+    // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+    const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+    const float d = (float)xb->d * (ls - 32);
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    for (int i = 0; i < 4; ++i) {
+        aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+        reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+        reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+        reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+        reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+    }
+}
+
+enum ggml_sort_order {
+    GGML_SORT_ORDER_ASC,
+    GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
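+// broadcast is implemented by wrapping the src1 indices with a modulo (i13 = i03 % ne13, etc.),
+// so a size-1 dimension in src1 is reused across the corresponding dimension of src0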
+kernel void kernel_add(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+kernel void kernel_sub(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+kernel void kernel_mul(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+kernel void kernel_div(
+        constant ggml_metal_kargs_bin & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
+
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
+    }
+}
+
+template<typename T>
+kernel void kernel_repeat(
+        constant ggml_metal_kargs_repeat & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    const int i03 = i3%args.ne03;
+    const int i02 = i2%args.ne02;
+    const int i01 = i1%args.ne01;
+
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i00 = i0%args.ne00;
+        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
+// assumption: src1 is a row
+// broadcast src1 into src0
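+// each thread handles one float4; the src1 index wraps every ne00/4 elements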
+kernel void kernel_add_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
+
+kernel void kernel_sub_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+        constant ggml_metal_kargs_bin & args,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
+    dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant     float  & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+        device const float * src0,
+        device       float * dst,
+        constant     float & min,
+        constant     float & max,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A     = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+
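+// tanh-based GELU approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))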
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        device const float4 * src0,
+        device       float4 * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_elu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+    dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f);
+}
+
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sqrt(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sqrt(src0[tpig]);
+}
+
+kernel void kernel_sin(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = cos(src0[tpig]);
+}
+
+kernel void kernel_sum_rows(
+        device const float * src0,
+        device       float * dst,
+        constant  int64_t & ne00,
+        constant  int64_t & ne01,
+        constant  int64_t & ne02,
+        constant  int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant  int64_t & ne10,
+        constant  int64_t & ne11,
+        constant  int64_t & ne12,
+        constant  int64_t & ne13,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant  int64_t & ne0,
+        constant  int64_t & ne1,
+        constant  int64_t & ne2,
+        constant  int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tpig[[thread_position_in_grid]]) {
+    int64_t i3 = tpig.z;
+    int64_t i2 = tpig.y;
+    int64_t i1 = tpig.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+
+    float row_sum = 0;
+
+    for (int64_t i0 = 0; i0 < ne00; i0++) {
+        row_sum += src_row[i0];
+    }
+
+    dst_row[0] = row_sum;
+}
+
+template<typename T>
+kernel void kernel_soft_max(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
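+    // slope = m0^(h+1) for the first n_head_log2 heads, m1^(2*(h - n_head_log2) + 1) for the rest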
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+        lsum += exp_psrc0;
+        pdst[i00] = exp_psrc0;
+    }
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        pdst[i00] *= inv_sum;
+    }
+}
+
+template<typename T>
+kernel void kernel_soft_max_4(
+        device const  char * src0,
+        device const  char * src1,
+        device        char * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = (tgpig) / (ne02*ne01);
+    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+    float slope = 1.0f;
+
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    }
+
+    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+
+    const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    // This barrier fixes a failing test
+    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    threadgroup_barrier(mem_flags::mem_none);
+
+    float sum = simd_sum(lsum);
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
+
+    const float inv_sum = 1.0f/sum;
+
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        pdst4[i00] *= inv_sum;
+    }
+}
+
+typedef decltype(kernel_soft_max<float>)    kernel_soft_max_t;
+typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
+
+template [[host_name("kernel_soft_max_f16")]]   kernel kernel_soft_max_t   kernel_soft_max<half>;
+template [[host_name("kernel_soft_max_f32")]]   kernel kernel_soft_max_t   kernel_soft_max<float>;
+template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
+template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
+
+kernel void kernel_diag_mask_inf(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant       int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+    const int64_t i02 = tpig[2];
+    const int64_t i01 = tpig[1];
+    const int64_t i00 = tpig[0];
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        device const float4 * src0,
+        device       float4 * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne01,
+        constant        int & n_past,
+        uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
+    }
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = ne10;
+  //const int64_t ncs = ne00;
+  //const int64_t nr  = ne01;
+  //const int64_t n_t = ne1;
+  //const int64_t n_s = ne2;
+
+    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc; ++i0) {
+        sumf += s[i0] * c[i0];
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// TODO: optimize
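+// per-channel selective-scan recurrence: state = state*exp(dt*A) + B*x*dt, y = sum(state*C),
+// with dt passed through softplus (log(1 + exp(dt)), skipped for large dt)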
+kernel void kernel_ssm_scan_f32(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device      float * dst,
+        constant  int64_t & d_state,
+        constant  int64_t & d_inner,
+        constant  int64_t & n_seq_tokens,
+        constant  int64_t & n_seqs,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant uint64_t & nb20,
+        constant uint64_t & nb21,
+        constant uint64_t & nb22,
+        constant uint64_t & nb30,
+        constant uint64_t & nb31,
+        constant uint64_t & nb40,
+        constant uint64_t & nb41,
+        constant uint64_t & nb42,
+        constant uint64_t & nb50,
+        constant uint64_t & nb51,
+        constant uint64_t & nb52,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i3 = tgpig.y;
+
+    const int64_t nc  = d_state;
+  //const int64_t nr  = d_inner;
+    const int64_t n_t = n_seq_tokens;
+  //const int64_t n_s = n_seqs;
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
+
+        if (i2 > 0) {
+            s0 = s;
+        }
+
+        // i1 == 0
+        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        float x_dt = x[0] * dt_soft_plus;
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            int64_t i = i0;
+            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+    }
+}
+
+kernel void kernel_argmax(
+        device   const void * x,
+        device      int32_t * dst,
+        constant    int64_t & ncols,
+        constant   uint64_t & nb01,
+        threadgroup   float * shared_maxval [[threadgroup(0)]],
+        threadgroup int32_t * shared_argmax [[threadgroup(1)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
+
+    float   lmax = -INFINITY;
+    int32_t larg = -1;
+
+    for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
+        if (x_row[i00] > lmax) {
+            lmax = x_row[i00];
+            larg = i00;
+        }
+    }
+
+    // find the argmax value in the block
+    float max_val = simd_max(lmax);
+    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            shared_maxval[tiisg] = -INFINITY;
+            shared_argmax[tiisg] = -1;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            shared_maxval[sgitg] = max_val;
+            shared_argmax[sgitg] = arg_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = shared_maxval[tiisg];
+        arg_val = shared_argmax[tiisg];
+
+        float max_val_reduced   = simd_max(max_val);
+        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
+
+        dst[tgpig] = arg_val_reduced;
+
+        return;
+    }
+
+    dst[tgpig] = arg_val;
+}
+
+kernel void kernel_norm(
+        constant ggml_metal_kargs_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float4 sumf4(0.0f);
+
+    float sumf = 0.0f;
+
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf4 += x[i00];
+    }
+    sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean = sumf/args.ne00;
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+
+    sumf = 0.0f;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] - mean;
+        sumf += dot(y[i00], y[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float variance = sumf/args.ne00;
+
+    const float scale = 1.0f/sqrt(variance + args.eps);
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
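+// RMS norm: y = x / sqrt(mean(x^2) + eps), using a simdgroup sum followed by a threadgroup reduction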
+kernel void kernel_rms_norm(
+        constant ggml_metal_kargs_rms_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float sumf = 0.0f;
+
+    // parallel sum
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean  = sumf/args.ne00;
+    const float scale = 1.0f/sqrt(mean + args.eps);
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        y[i00] = x[i00] * scale;
+    }
+}
+
+kernel void kernel_group_norm(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int32_t & n_groups,
+        constant     float & eps,
+        threadgroup float  * buf [[threadgroup(0)]],
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    const int64_t ne = ne00*ne01*ne02;
+    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+    int start = tgpig * gs;
+    int end   = start + gs;
+
+    start += tpitg;
+
+    if (end >= ne) {
+        end = ne;
+    }
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int j = start; j < end; j += ntg) {
+        tmp += src0[j];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float mean = tmp / gs;
+    tmp = 0.0f;
+
+    for (int j = start; j < end; j += ntg) {
+        float xi = src0[j] - mean;
+        dst[j] = xi;
+        tmp += xi * xi;
+    }
+
+    tmp = simd_sum(tmp);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            buf[sgitg] = tmp;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        tmp = buf[tiisg];
+        tmp = simd_sum(tmp);
+    }
+
+    const float variance = tmp / gs;
+    const float scale = 1.0f/sqrt(variance + eps);
+    for (int j = start; j < end; j += ntg) {
+        dst[j] *= scale;
+    }
+}
+
+// function for calculating the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
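+// the nibble masks 0x000F/0x0F00/0x00F0/0xF000 pick the four 4-bit quants out of each uint16,
+// and the implicit -8 offset of q4_0 is folded in once via sumy * -8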
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
+
+    for (int i = 0; i < 8; i += 2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+
+    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
+
+// function for calculating the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
+}
+
+// function for calculating the inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+
+    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
+}
+
+// function for calculating the inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
+}
+
+// putting them in the kernel causes a significant performance penalty
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      guard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
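+// each of the N_SIMDGROUP simdgroups in a threadgroup accumulates N_DST output rows;
+// per-thread partial sums are reduced with simd_sum before being written out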
+template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+void mul_vec_q_n_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const int nb = args.ne00/QK4_0;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+  //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q_type * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+    }
+
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const short ix = (tiisg/2);
+    const short il = (tiisg%2)*8;
+
+    device const float * yb = y + ix*QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        float sumy[2] = { 0.f, 0.f };
+
+#pragma unroll
+        for (int i = 0; i < 8; i += 2) {
+            sumy[0]  += yb[i +  0] + yb[i +  1];
+            yl[i + 0] = yb[i +  0];
+            yl[i + 1] = yb[i +  1]/256.f;
+
+            sumy[1]  += yb[i + 16] + yb[i + 17];
+            yl[i + 8] = yb[i + 16]/16.f;
+            yl[i + 9] = yb[i + 17]/4096.f;
+        }
+
+#pragma unroll
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
+        }
+
+        yb += QK4_0 * 16;
+    }
+
+    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && first_row + row < args.ne01) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+kernel void kernel_mul_mv_q4_0_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q4_1_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q5_0_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+#define NB_Q8_0 8
+
+template<typename args_t>
+void kernel_mul_mv_q8_0_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = args.ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0*nsg + sgitg)*nr;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+  //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q8_0 * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
+    }
+
+    float yl[NB_Q8_0];
+    float sumf[nr] = { 0.f };
+
+    const short ix = tiisg/4;
+    const short il = tiisg%4;
+
+    device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
+
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (short i = 0; i < NB_Q8_0; ++i) {
+            yl[i] = yb[i];
+        }
+
+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
+            float sumq = 0.f;
+            for (short iq = 0; iq < NB_Q8_0; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*ax[row][ib].d;
+        }
+
+        yb += nw*NB_Q8_0;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+
+        if (tiisg == 0 && first_row + row < args.ne01) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+void kernel_mul_mv_ext_q4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 4; // chunks per thread
+
+  //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4 * y4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb; // current chunk index
+
+    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+        float4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq  += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    // reduce only the threads in each row
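+    // tree reduction with simd_shuffle_down across the nxpsg lanes that share a row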
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// mat-vec kernel processing in chunks of float4x4
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+void kernel_mul_mv_ext_q4x4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 1;
+
+  //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4x4 * y4x4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb;
+
+    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+        float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4x4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq  += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] +=
+                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4x4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
+
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
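+// args.nxpsg (threads along x per simdgroup) is chosen by the host at runtime; the switch below
+// selects a matching instantiation so nxpsg becomes a compile-time constant inside the impl kernels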
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4_f32_impl<4,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4_f32_impl<8,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4x4_f32_impl<4,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4x4_f32_impl<8,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)    mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4,        4,  dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0,   32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1,   32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0,   32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1,   32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
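+// number of src1 rows (ne11) processed per threadgroup by kernel_mul_mv_impl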
+#define N_MV_T_T 4
+
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
+void kernel_mul_mv_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T0 * x = (device const T0 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    if (args.ne00 < 128) {
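+        // short rows: plain scalar loop, each of the 32 simdgroup threads accumulates every 32nd element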
+        for (int row = 0; row < N_MV_T_T; ++row) {
+            int r1 = rb + row;
+            if (r1 >= args.ne11) {
+                break;
+            }
+
+            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+            device const T1 * y = (device const T1 *) (src1 + offset1);
+
+            float sumf = 0;
+            for (int i = tiisg; i < args.ne00; i += 32) {
+                sumf += (T0) x[i] * (T1) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            }
+        }
+    } else {
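+        // longer rows: float4 dot products; the scalar tail (ne00 % 4) is added by thread 0 before storing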
+        device const T04 * x4 = (device const T04 *) x;
+        for (int row = 0; row < N_MV_T_T; ++row) {
+            int r1 = rb + row;
+            if (r1 >= args.ne11) {
+                break;
+            }
+
+            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+            device const T1  * y  = (device const T1  *) (src1 + offset1);
+            device const T14 * y4 = (device const T14 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < args.ne00/4; i += 32) {
+                sumf += dot((float4) x4[i], (float4) y4[i]);
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
+        src0,
+        src1,
+        dst,
+        tgpig,
+        tiisg);
+}
+
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
+#endif
+
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const T     * x = (device const T     *) (src0 + offset0);
+    device const float * y = (device const float *) (src1 + offset1);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    float sumf = 0;
+    if (args.ne00 < 128) {
+        for (int i = tiisg; i < args.ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst_f32[r0] = all_sum;
+        }
+    } else {
+        device const T4     * x4 = (device const T4     *) x;
+        device const float4 * y4 = (device const float4 *) y;
+
+        for (int i = tiisg; i < args.ne00/4; i += 32) {
+            sumf += dot((float4) x4[i], y4[i]);
+        }
+
+        float all_sum = simd_sum(sumf);
+
+        if (tiisg == 0) {
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+            dst_f32[r0] = all_sum;
+        }
+    }
+}
+
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
+#endif
+
+// Assumes row size (ne00) is a multiple of 4
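+// unlike kernel_mul_mv, a single threadgroup iterates over all src1 rows (ne11) for its src0 row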
+template<typename T4>
+kernel void kernel_mul_mv_l4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = args.ne11;
+    const int r0 = tgpig.x;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+    device const T4 * x4 = (device const T4 *) (src0 + offset0);
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
+
+        device const float4 * y4 = (device const float4 *) (src1 + offset1);
+
+        float sumf = 0;
+        for (int i = tiisg; i < args.ne00/4; i += 32) {
+            sumf += dot((float4) x4[i], y4[i]);
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+        }
+    }
+}
+
+typedef decltype(kernel_mul_mv_l4<half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half4>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat4>;
+#endif
+
+static float rope_yarn_ramp(const float low, const float high, const int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
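+// theta_interp is the position-interpolated angle; ramp_mix blends it back toward the extrapolated
+// angle for high-frequency dimensions, while mscale compensates the attention magnitude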
+static void rope_yarn(
+    float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
+    thread float * cos_theta, thread float * sin_theta) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    *cos_theta = cos(theta) * mscale;
+    *sin_theta = sin(theta) * mscale;
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+static void rope_yarn_corr_dims(
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    dims[0] = max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
+    dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
+}
+
+template<typename T>
+kernel void kernel_rope_norm(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_neox(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims/2];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
+
+typedef void (im2col_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+//    const int64_t IC = tgpg[0];
+    const int64_t OH = tgpg[1];
+    const int64_t OW = tgpg[2];
+
+//    const int64_t N  = ntg[0];
+    const int64_t KH = ntg[1];
+    const int64_t KW = ntg[2];
+
+    const int64_t in  = tpitg[0];
+    const int64_t ikh = tpitg[1];
+    const int64_t ikw = tpitg[2];
+
+    const int64_t iic = tgpig[0];
+    const int64_t ioh = tgpig[1];
+    const int64_t iow = tgpig[2];
+
+    const int64_t iiw = iow*s0 + ikw*d0 - p0;
+    const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+        pdst[offset_dst] = x[offset_src];
+    }
+}
+
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
+
+typedef void (im2col_ext_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col_ext(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
+    const int64_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+
+    const int64_t d = tgpig[0] / CHW;
+    const int64_t chw = tgpig[0] % CHW;
+    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
+    const int64_t HW = tgpig[0] % KHW;
+
+    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    if (tpitg_0 >= N) {
+        return;
+    }
+
+    const int64_t tpitg_1 = HW / KW;
+    const int64_t tpitg_2 = HW % KW;
+
+    const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+    const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+
+    const int64_t offset_dst =
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
+
+typedef void (conv_transpose_1d_t)(
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_1d(
+        device const     T * src0,
+        device const float * src1,
+        device        char * dst,
+        constant   int32_t & IC,
+        constant   int32_t & IL,
+        constant   int32_t & K,
+        constant   int32_t & s0,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3   tgpg[[threadgroups_per_grid]]) {
+
+    float v = 0.0f;
+
+    for (int64_t c = 0; c < IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+        const int32_t input_offset = c * IL;
+
+        for (int64_t i = 0; i < IL; i++) {
+            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+            }
+        }
+    }
+
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+    dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+    device const float * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+    device const half  * src0,
+    device const float * src1,
+    device        char * dst,
+    constant   int32_t & IC,
+    constant   int32_t & IL,
+    constant   int32_t & K,
+    constant   int32_t & s0,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    uint3    tgpg[[threadgroups_per_grid]]);
+
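+// nearest-neighbor upscale: each destination coordinate i maps back to source coordinate trunc(i/sf)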
+kernel void kernel_upscale_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant     float & sf0,
+    constant     float & sf1,
+    constant     float & sf2,
+    constant     float & sf3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3/sf3;
+    const int64_t i02 = i2/sf2;
+    const int64_t i01 = i1/sf1;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int64_t i00 = i0/sf0;
+
+        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+
+        dst_ptr[0] = src0_ptr[0];
+    }
+}
+
+kernel void kernel_pad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            } else {
+                dst_ptr[i0] = 0.0f;
+            }
+        }
+
+        return;
+    }
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = 0.0f;
+    }
+}
+
+kernel void kernel_pad_reflect_1d_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant   int64_t & ne0,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    constant   int32_t & p0,
+    constant   int32_t & p1,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3  tgpg[[threadgroups_per_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < p0) {
+                dst_ptr[i0] = src0_ptr[p0 - i0];
+            } else if (i0 < ne0 - p1) {
+                dst_ptr[i0] = src0_ptr[i0 - p0];
+            } else {
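+                // right pad: mirror around the last source element, i.e. source index (ne0 - p1 - p0) - 2 - (i0 - (ne0 - p1))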
+                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+            }
+        }
+    }
+}
+
+kernel void kernel_unpad_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3;
+    const int64_t i02 = i2;
+    const int64_t i01 = i1;
+
+    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+
+    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+            if (i0 < ne00) {
+                dst_ptr[i0] = src0_ptr[i0];
+            }
+        }
+
+        return;
+    }
+}
+
+kernel void kernel_arange_f32(
+    device        char * dst,
+    constant   int64_t & ne0,
+    constant   float   & start,
+    constant   float   & step,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    device float * dst_ptr = (device float *) dst;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        dst_ptr[i0] = start + step * i0;
+    }
+}
+
+kernel void kernel_timestep_embedding_f32(
+    device  const char * src0,
+    device        char * dst,
+    constant  uint64_t & nb1,
+    constant  int      & dim,
+    constant  int      & max_period,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    int i = tgpig.x;
+    device float * embed_data = (device float *)(dst +  i*nb1);
+
+    int half_ = dim / 2;
+    for (int j = tpitg.x; j < half_; j += ntg.x) {
+        float timestep = ((device float *)src0)[i];
+        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float arg = timestep * freq;
+        embed_data[j        ] = cos(arg);
+        embed_data[j + half_] = sin(arg);
+    }
+
+    if (dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[dim] = 0.f;
+    }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
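+// the bitonic network needs a power-of-two length, hence the padded column count ncols_pad; padded
+// entries hold indices >= ncols, always compare as "last", and are dropped in the final copy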
+typedef void (argsort_t)(
+        device const float  * x,
+        device     int32_t  * dst,
+        constant   int64_t  & ncols,
+        constant   int64_t  & ncols_pad,
+        threadgroup int32_t * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+        device const float   * x,
+        device       int32_t * dst,
+        constant     int64_t & ncols,
+        constant     int64_t & ncols_pad,
+        threadgroup int32_t  * shared_values [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]]) {
+    // bitonic sort
+    int col = tpitg[0];
+    int row = tgpig[1];
+
+    if (col >= ncols_pad) return;
+
+    device const float   * x_row   = x + row * ncols;
+    threadgroup int32_t  * dst_row = shared_values;
+
+    // initialize indices
+    dst_row[col] = col;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        SWAP(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+        }
+    }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+template [[host_name("kernel_argsort_f32_i32_asc")]]  kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
+
+kernel void kernel_leaky_relu_f32(
+        device const float * src0,
+        device       float * dst,
+        constant     float & slope,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
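+// each threadgroup handles Q query rows; its simdgroups walk the KV cache in blocks of C columns,
+// maintaining an online softmax (running max M and denominator S), and the per-simdgroup partial
+// outputs are merged at the end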
+template<
+    typename q_t,     // query types in shared memory
+    typename q4_t,
+    typename q8x8_t,
+    typename k_t,     // key types in shared memory
+    typename k4x4_t,
+    typename k8x8_t,
+    typename v_t,     // value types in shared memory
+    typename v4x4_t,
+    typename v8x8_t,
+    typename qk_t,    // Q*K types
+    typename qk8x8_t,
+    typename s_t,     // soft-max types
+    typename s8x8_t,
+    typename o_t,     // attention accumulation types
+    typename o4_t,
+    typename o8x8_t,
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,         // head size
+    short Q  = 8,    // queries per threadgroup
+    short KV = 8,    // key/value processed per each simdgroup
+    short C  = 32>   // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3   ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0]*Q;
+
+    const short D4  = D/4;
+    const short D8  = D/8;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
+
+    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
+    const short T  = D + 2*TS; // shared memory size per query in (half)
+
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*D); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*D); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*D); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+
+    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    o8x8_t lo[D8];
+
+    // load heads from Q to shared memory
+    for (short j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+
+        for (short i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < args.ne01) {
+                sq4[j*D4 + i] = (q4_t) q4[i];
+            } else {
+                sq4[j*D4 + i] = (q4_t) 0.0f;
+            }
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D8; ++i) {
+        lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f);
+    }
+
+    // zero out shared memory SH
+    for (short j = 0; j < Q; ++j) {
+        for (short i = tiisg; i < SH; i += NW) {
+            ss[j*TS + i] = 0.0f;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        half S[Q] = { [0 ... Q-1] = 0.0f };
+        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+
+        // thread indices inside the simdgroup
+        // TODO: see if we can utilize quad-group functions for better performance
+        //       https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
+        const short tx = tiisg%4;
+        const short ty = tiisg/4;
+
+        // broadcast kv
+        //const short rk2 = args.ne02/args.ne12;
+        //const short rk3 = args.ne03/args.ne13;
+
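+        // grouped-query attention: ne02/ne_12_2 query heads map onto each KV head (likewise for ne03/ne_12_3)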
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+        // load the queries from shared memory into local memory
+        q8x8_t mq[D8];
+
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_load(mq[i], sq + i*8, D);
+        }
+
+        const bool has_mask = mask != q;
+
+        half slope = 1.0f;
+
+        // ALiBi
+        if (args.max_bias > 0.0f) {
+            const short h = iq2;
+
+            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+            slope = pow(base, exph);
+        }
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= args.ne11) {
+                break;
+            }
+
+            if (has_mask) {
+                // used to detect blocks full of -INF
+                half smax = -INFINITY;
+
+                // load the mask in shared memory
+                #pragma unroll(Q)
+                for (short j = 0; j < Q; ++j) {
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
+
+                    const half m = pm[ic + tiisg];
+
+                    ss[j*TS + C + tiisg] = m;
+                    smax = max(smax, m);
+                }
+
+                smax = simd_max(smax);
+
+                if (smax == -INFINITY) {
+                    continue;
+                }
+            }
+
+            // Q*K^T
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f);
+
+                    // this is compile-time check, so it does not have runtime overhead
+                    if (is_same<kd4x4_t, k4x4_t>::value) {
+                        // we can read directly from global memory
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                        #pragma unroll(D8)
+                        for (short i = 0; i < D8; ++i) {
+                            k8x8_t mk;
+                            simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+
+                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                            if (D16%4 == 0) {
+                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+                                {
+                                    k4x4_t tmp;
+                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                    sk4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                #pragma unroll(4)
+                                for (short k = 0; k < 4; ++k) {
+                                    k8x8_t mk;
+
+                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    k4x4_t tmp;
+                                    deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                                    sk4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    k8x8_t mk;
+
+                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
+
+                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
+                                }
+                            }
+                        }
+                    }
+
+                    // cast qk_t -> s_t
+                    //s8x8_t mqks(1.0f);
+                    //simdgroup_multiply(mqks, mqk, mqks);
+                    //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
+                }
+            }
+
+            // online softmax
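+            // M[j] is the running row maximum and S[j] the running denominator; when M grows, previous
+            // contributions are rescaled by ms = exp(m - M) (applied to O via the diagonal matrix below)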
+            {
+                for (ushort j = 0; j < Q; ++j) {
+                    const half m = M[j];
+
+                    // scale and apply the logit softcap / mask
+                    half s = ss[j*TS + tiisg]*args.scale;
+
+                    if (args.logit_softcap != 0.0f) {
+                        s = args.logit_softcap*precise::tanh(s);
+                    }
+
+                    // mqk = mqk + mask*slope
+                    s += slope*ss[j*TS + C + tiisg];
+
+                    M[j] = simd_max(max(M[j], s));
+
+                    const half ms = exp(m - M[j]);
+                    const half vs = exp(s - M[j]);
+
+                    S[j] = S[j]*ms + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*TS + tiisg] = vs;
+
+                    // create a QxQ diagonal matrix for rescaling the output
+                    if (tiisg == j) {
+                        ss[j*TS + 2*C + j] = ms;
+                    }
+                }
+            }
+
+            // O = diag(ms)*O
+            {
+                s8x8_t mm;
+                simdgroup_load(mm, ss + 2*C, TS, 0, false);
+
+                #pragma unroll(D8)
+                for (short i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[i], mm, lo[i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/8; ++cc) {
+                    s8x8_t ms;
+                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+
+                    if (is_same<vd4x4_t, v4x4_t>::value) {
+                        // we can read directly from global memory
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                        #pragma unroll(D8)
+                        for (short i = 0; i < D8; ++i) {
+                            v8x8_t mv;
+                            simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+
+                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                        }
+                    } else {
+                        for (short ii = 0; ii < D16; ii += 4) {
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                            if (D16%4 == 0) {
+                                // no need for bound checks
+                                {
+                                    v4x4_t tmp;
+                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                                    sv4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                #pragma unroll(4)
+                                for (short k = 0; k < 4; ++k) {
+                                    v8x8_t mv;
+
+                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                }
+                            } else {
+                                if (ii + tx < D16) {
+                                    v4x4_t tmp;
+                                    deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                                    sv4x4[4*ty + tx] = tmp;
+                                }
+
+                                simdgroup_barrier(mem_flags::mem_threadgroup);
+
+                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                                    v8x8_t mv;
+
+                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+
+                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (short j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*TS + 0] = S[j];
+                ss[j*TS + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
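+    // merge the per-simdgroup (S, M, O) partials: each is rescaled by exp(M_i - M) against the joint maximum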
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        half S = { 0.0f };
+        half M = { -__FLT16_MAX__/2 };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (short i = 0; i < D8; ++i) {
+                simdgroup_store(lo[i], so + i*8, D, 0, false);
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (short j = 0; j < Q; ++j) {
+                const half S0 = ss[j*TS +         0];
+                const half S1 = ss[j*TS + sg*SH + 0];
+
+                const half M0 = ss[j*TS +         1];
+                const half M1 = ss[j*TS + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const half ms0 = exp(M0 - M);
+                const half ms1 = exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*TS + 0] = S;
+                    ss[j*TS + 1] = M;
+
+                    ss[j*TS + 2*C + j        ] = ms0;
+                    ss[j*TS + 2*C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            {
+                s8x8_t ms0;
+                s8x8_t ms1;
+
+                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+
+                #pragma unroll(D8)
+                for (short i = 0; i < D8; ++i) {
+                    o8x8_t t;
+
+                    simdgroup_load    (t, so + i*8, D, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (short i = 0; i < D8; ++i) {
+            simdgroup_store(lo[i], so + i*8, D, 0, false);
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
+            const float S = ss[j*TS + 0];
+
+            for (short i = tiisg; i < D4; i += NW) {
+                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+            }
+        }
+    }
+}
+
+// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now
+//       they are kept as template parameters to be able to explore different combinations
+//
+#define FA_TYPES \
+    half,  half4,   simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    half,  half4x4, simdgroup_half8x8,  \
+    float,          simdgroup_float8x8, \
+    float,          simdgroup_float8x8, \
+    half,  half4,   simdgroup_half8x8
+
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+#endif
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
+
+#undef FA_TYPES
+
+template<
+    typename q4_t,    // query types in shared memory
+    typename q4x4_t,
+    typename k4x4_t,  // key types in shared memory
+    typename v4x4_t,  // value types in shared memory
+    typename qk_t,    // Q*K types
+    typename s_t,     // soft-max types
+    typename s4_t,
+    typename s4x4_t,
+    typename o4x4_t,  // attention accumulation types
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,         // head size
+    short Q  = 1,    // queries per threadgroup
+    short C  = 32>   // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3   ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short nsg = ntg.y; // number of simdgroups
+
+    const int iq3 = tgpig[2];
+    const int iq2 = tgpig[1];
+    const int iq1 = tgpig[0];
+
+    const short D4  = D/4;
+    const short D16 = D/16;
+    const short NW  = N_SIMDWIDTH;
+    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
+    const short SH  = 2*C;  // shared memory per simdgroup
+
+    const short T = D + nsg*SH; // shared memory size per query in (half)
+
+  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 +                0*D); // holds the query data
+    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 +                0*D); // same as above but in q4_t
+    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 +                0*D); // same as above but in q4x4_t
+    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + sgitg*SH     + Q*D); // scratch buffer for attention
+    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH     + Q*D); // same as above but in s4_t
+    threadgroup half   * sm    = (threadgroup half   *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
+    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D      + Q*T); // scratch buffer for the results
+
+    // store the result for all queries in local memory (the O matrix from the paper)
+    o4x4_t lo[D16/NL];
+
+    // load heads from Q to shared memory
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+
+    for (short i = tiisg; i < D4; i += NW) {
+        if (iq1 < args.ne01) {
+            sq4[i] = (q4_t) q4[i];
+        } else {
+            sq4[i] = (q4_t) 0.0f;
+        }
+    }
+
+    // zero out lo
+    for (short i = 0; i < D16/NL; ++i) {
+        lo[i] = (o4x4_t) 0.0f;
+    }
+
+    // zero out shared memory SH
+    for (short i = tiisg; i < SH/4; i += NW) {
+        ss4[i] = (s4_t) 0.0f;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        half S = 0.0f;
+        half M = -__FLT16_MAX__/2;
+
+        // thread indices inside the simdgroup
+        const short tx = tiisg%NL;
+        const short ty = tiisg/NL;
+
+        // broadcast kv
+        //const short rk2 = args.ne02/args.ne12;
+        //const short rk3 = args.ne03/args.ne13;
+
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+        // load the queries from shared memory into local memory
+        q4x4_t mq[D16/NL];
+
+        #pragma unroll(D16/NL)
+        for (short ii = 0; ii < D16; ii += NL) {
+            mq[ii/NL] = sq4x4[ii + tx];
+        }
+
+        const bool has_mask = mask != q;
+
+        // pointer to the mask
+        device const half * pm = (device const half *) (mask + iq1*args.nb31);
+
+        half slope = 1.0f;
+
+        // ALiBi
+        if (args.max_bias > 0.0f) {
+            const short h = iq2;
+
+            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
+
+            slope = pow(base, exph);
+        }
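+
+        // note: per-head ALiBi slope: heads below n_head_log2 use m0^(h+1), the remaining heads
+        //       use m1^(2*(h - n_head_log2) + 1); the slope scales the additive mask applied in
+        //       the Q*K^T block below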
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+            const int ic = ic0 + C*sgitg;
+            if (ic >= args.ne11) {
+                break;
+            }
+
+            if (has_mask) {
+                sm[tiisg] = pm[ic + tiisg];
+            }
+
+            // Q*K^T
+            {
+                // each simdgroup processes 1 query and 4 (NW/NL) keys
+                for (short cc = 0; cc < C/4; ++cc) {
+                    qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
+
+                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                    #pragma unroll(D16/NL)
+                    for (short ii = 0; ii < D16; ii += NL) {
+                        const short i = ii + tx;
+
+                        k4x4_t mk;
+                        deq_k(pk + i/nl_k, i%nl_k, mk);
+
+                        // note: this is less precise than the version below
+                        //mqka[0] += dot(mq[ii/NL][0], mk[0]);
+                        //mqka[1] += dot(mq[ii/NL][1], mk[1]);
+                        //mqka[2] += dot(mq[ii/NL][2], mk[2]);
+                        //mqka[3] += dot(mq[ii/NL][3], mk[3]);
+
+                        mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
+                        mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
+                        mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
+                        mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
+                    }
+
+                    qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
+
+                    // simdgroup reduce
+                    // [ 0 ..  7] -> [ 0]
+                    // [ 8 .. 15] -> [ 8]
+                    // [16 .. 23] -> [16]
+                    // [24 .. 31] -> [24]
+                  //mqk += simd_shuffle_down(mqk, 16);
+                  //mqk += simd_shuffle_down(mqk,  8);
+                    mqk += simd_shuffle_down(mqk,  4);
+                    mqk += simd_shuffle_down(mqk,  2);
+                    mqk += simd_shuffle_down(mqk,  1);
+
+                    // mqk = mqk*scale + mask*slope
+                    if (tx == 0) {
+                        mqk *= args.scale;
+
+                        if (args.logit_softcap != 0.0f) {
+                            mqk = args.logit_softcap*precise::tanh(mqk);
+                        }
+
+                        mqk += sm[4*cc + ty]*slope;
+
+                        ss[4*cc + ty] = mqk;
+                    }
+                }
+            }
+
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
+            // online softmax
+            {
+                const half m = M;
+                const half s = ss[tiisg];
+
+                M = simd_max(max(M, s));
+
+                const half ms = exp(m - M);
+                const half vs = exp(s - M);
+
+                S = S*ms + simd_sum(vs);
+
+                // the P matrix from the paper (Q rows, C columns)
+                ss[tiisg] = vs;
+
+                // O = diag(ms)*O
+                #pragma unroll(D16/NL)
+                for (short ii = 0; ii < D16; ii += NL) {
+                    lo[ii/NL] *= ms;
+                }
+            }
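+
+            // note: S and M are the running softmax sum and running maximum for this query row;
+            //       ms rescales the previous accumulator to the new maximum and vs holds the
+            //       weights of the current scores, keeping O and S consistent across KV blocks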
+
+            simdgroup_barrier(mem_flags::mem_threadgroup);
+
+            // O = O + (Q*K^T)*V
+            {
+                for (short cc = 0; cc < C/4; ++cc) {
+                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
+
+                    const s4x4_t ms(ss[4*cc + ty]);
+
+                    #pragma unroll(D16/NL)
+                    for (short ii = 0; ii < D16; ii += NL) {
+                        const short i = ii + tx;
+
+                        v4x4_t mv;
+                        deq_v(pv4 + i/nl_v, i%nl_v, mv);
+
+                        lo[ii/NL] += mv*ms;
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (tiisg == 0) {
+            ss[0] = (s_t) S;
+            ss[1] = (s_t) M;
+        }
+    }
+
+    // simdgroup reduce
+    // [ 0,  8, 16, 24] -> [ 0]
+    // [ 1,  9, 17, 25] -> [ 1]
+    // [ 2, 10, 18, 26] -> [ 2]
+    // [ 3, 11, 19, 27] -> [ 3]
+    // [ 4, 12, 20, 28] -> [ 4]
+    // [ 5, 13, 21, 29] -> [ 5]
+    // [ 6, 14, 22, 30] -> [ 6]
+    // [ 7, 15, 23, 31] -> [ 7]
+    for (short ii = 0; ii < D16; ii += NL) {
+        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+
+        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+
+        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+
+        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // store results to shared memory
+    for (short i = tiisg; i < D16; i += NL) {
+        sr4x4[i] = lo[i/NL];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // parallel reduce
+    for (short r = nsg/2; r > 0; r >>= 1) {
+        if (sgitg < r) {
+            const half S0 = ss[       0];
+            const half S1 = ss[r*SH + 0];
+
+            const half M0 = ss[       1];
+            const half M1 = ss[r*SH + 1];
+
+            const half M = max(M0, M1);
+
+            const half ms0 = exp(M0 - M);
+            const half ms1 = exp(M1 - M);
+
+            const half S = S0*ms0 + S1*ms1;
+
+            if (tiisg == 0) {
+                ss[0] = S;
+                ss[1] = M;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (short i = tiisg; i < D16; i += NW) {
+                sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
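+
+    // note: after log2(nsg) rounds of this tree reduction, slot 0 of ss and the first D16 entries
+    //       of sr4x4 hold the (S, M, O) partials of all simdgroups merged with the usual
+    //       max-rescaling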
+
+    device float4x4 * dst44 = (device float4x4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        const float S = ss[0];
+
+        for (short i = tiisg; i < D16; i += NW) {
+            dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+        }
+    }
+}
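+
+// note: compared to kernel_flash_attn_ext above, this _vec variant processes a single query row
+//       (Q = 1) per threadgroup and accumulates with scalar dot products and simd_shuffle
+//       reductions instead of simdgroup matrices, which suits the case of very few query rows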
+
+// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+//       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+//
+#define FA_TYPES \
+           half4,  half4x4, \
+                   half4x4, \
+                   half4x4, \
+    float,                  \
+    half,  half4,  half4x4, \
+                   half4x4
+
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
+#undef FA_TYPES
+
+template<typename T>
+kernel void kernel_set(
+    constant ggml_metal_kargs_set & args,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    ushort3 tpitg[[thread_position_in_threadgroup]],
+    ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i13 = tgpig[2];
+    const int i12 = tgpig[1];
+    const int i11 = tgpig[0];
+
+    const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
+
+    const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
+    const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
+    const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
+
+    device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
+
+    for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
+        device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
+        dst_data[i10] = (T) src[0];
+    }
+}
+
+typedef decltype(kernel_set<float>) kernel_set_t;
+
+template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set;
+template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set;
+
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+        dst_data[i00] = (T1) src[0];
+    }
+}
+
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+
+template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy;
+template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy;
+#endif
+template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy;
+template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy;
+template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy;
+#endif
+
+kernel void kernel_cpy_f32_q8_0(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
+
+    device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = src[j];
+            amax = MAX(amax, fabs(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK8_0].d = d;
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = src[j]*id;
+
+            dst_data[i00/QK8_0].qs[j] = round(x0);
+        }
+    }
+}
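+
+// note: Q8_0 stores one half-precision scale d = amax/127 per block of QK8_0 values plus the
+//       rounded quotients qs[j] = round(x[j]/d); e.g. with amax = 2.54 the scale is d = 0.02,
+//       so 1.27 is stored as round(63.5) = 64 and dequantizes back to 1.28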
+
+kernel void kernel_cpy_f32_q4_0(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
+
+    device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_0].d = d;
+
+        for (int j = 0; j < QK4_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK4_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            dst_data[i00/QK4_0].qs[j]  = xi0;
+            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
+        }
+    }
+}
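+
+// note: Q4_0 maps the value with the largest magnitude to -8 (d = max/-8), offsets everything
+//       by 8 and clamps to [0, 15], packing two 4-bit values per byte (low and high nibble)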
+
+kernel void kernel_cpy_f32_q4_1(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
+
+    device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < QK4_1; j++) {
+            const float v = src[j];
+            if (min > v) min = v;
+            if (max < v) max = v;
+        }
+
+        const float d = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK4_1].d = d;
+        dst_data[i00/QK4_1].m = min;
+
+        for (int j = 0; j < QK4_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            dst_data[i00/QK4_1].qs[j]  = xi0;
+            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+kernel void kernel_cpy_f32_q5_0(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
+
+    device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK5_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_0].d = d;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_0/2; ++j) {
+            const float x0 = src[0       + j]*id;
+            const float x1 = src[QK5_0/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_0].qh[j] = qh8[j];
+        }
+    }
+}
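+
+// note: for the 5-bit formats the low four bits of each value are packed into the qs nibbles,
+//       while the fifth bits are gathered into the 32-bit qh word, one bit per element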
+
+kernel void kernel_cpy_f32_q5_1(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
+
+    device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float max = src[0];
+        float min = src[0];
+
+        for (int j = 1; j < QK5_1; j++) {
+            const float v = src[j];
+            min = v < min ? v : min;
+            max = v > max ? v : max;
+        }
+
+        const float d = (max - min) / 31;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        dst_data[i00/QK5_1].d = d;
+        dst_data[i00/QK5_1].m = min;
+
+        uint32_t qh = 0;
+        for (int j = 0; j < QK5_1/2; ++j) {
+            const float x0 = (src[0       + j] - min)*id;
+            const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+        }
+        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+        for (int j = 0; j < 4; ++j) {
+            dst_data[i00/QK5_1].qh[j] = qh8[j];
+        }
+    }
+}
+
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
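+
+// note: best_index_int8 binary-searches the sorted non-linear codebook (kvalues_iq4nl_f) and
+//       returns the index of the entry closest to x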
+
+kernel void kernel_cpy_f32_iq4_nl(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
+
+    device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < QK4_0; j++) {
+            const float v = src[j];
+            if (amax < fabs(v)) {
+                amax = fabs(v);
+                max  = v;
+            }
+        }
+
+        const float d = max / kvalues_iq4nl_f[0];
+        const float id = d ? 1.0f/d : 0.0f;
+
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            const float x0 = src[0        + j]*id;
+            const float x1 = src[QK4_NL/2 + j]*id;
+
+            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
+
+            const float v0 = kvalues_iq4nl_f[xi0];
+            const float v1 = kvalues_iq4nl_f[xi1];
+            const float w0 = src[0        + j]*src[0        + j];
+            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+            sumq2 += w0*v0*v0 + w1*v1*v1;
+
+        }
+
+        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+    }
+}
+
+kernel void kernel_concat(
+    constant ggml_metal_kargs_concat & args,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    ushort3 tpitg[[thread_position_in_threadgroup]],
+    ushort3   ntg[[threads_per_threadgroup]]) {
+
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    int o[4] = {0, 0, 0, 0};
+    o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
+
+    device const float * x;
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+            x = (device const float *)(src0 + (i3       )*args.nb03 + (i2       )*args.nb02 + (i1       )*args.nb01 + (i0       )*args.nb00);
+        } else {
+            x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
+        }
+
+        device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+        *y = *x;
+    }
+}
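+
+// note: o[args.dim] holds the extent of src0 along the concatenated dimension, so any index
+//       beyond it falls through to src1 with that offset subtracted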
+
+template<typename args_t>
+void kernel_mul_mv_q2_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int iq = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
+
+    device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }
+
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*iq + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += args.nb01/2;
+            sc += args.nb01;
+            dh += args.nb01/2;
+        }
+
+        y4 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
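+
+// note: each simdgroup accumulates N_DST consecutive output rows; the per-lane partial sums are
+//       combined with simd_sum and lane 0 writes the final values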
+
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q3_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
+
+    float yl[32];
+
+    //const uint16_t kmask1 = 0x3030;
+    //const uint16_t kmask2 = 0x0f0f;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;
+
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
+
+    const short shift = 2*il;
+
+    const float v1 = il == 0 ? 4.f : 64.f;
+    const float v2 = 4.f * v1;
+
+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + il;
+
+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
+        for (int l = 0; l < 8; ++l) {
+            yl[l+ 0] = y1[l+ 0];
+            yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
+        }
+
+        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
+
+        for (int row = 0; row < 2; ++row) {
+            const float d_all = (float)dh[0];
+
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
+            }
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
+
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
+            }
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
+
+            q  += args.nb01/2;
+            h  += args.nb01/2;
+            a  += args.nb01/2;
+            dh += args.nb01/2;
+        }
+
+        y1 += 4 * QK_K;
+    }
+
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst_f32[first_row + row] = sumf1[row];
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q4_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int iq = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
+
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
+        }
+
+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += args.nb01/2;
+            sc += args.nb01/2;
+            dh += args.nb01/2;
+        }
+
+        y4 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q5_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
+
+    float sumf[2]={0.f};
+
+    float yl[16], yh[16];
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int iq  = tid/4;
+    const int ir  = tid%4;
+    const int n   = 8;
+
+    const int l0 = n*ir;
+    const int q_offset = 32*iq + l0;
+    const int y_offset = 64*iq + l0;
+
+    const uint8_t hm1 = 1u << (2*iq);
+    const uint8_t hm2 = hm1 << 1;
+    const uint8_t hm3 = hm1 << 4;
+    const uint8_t hm4 = hm2 << 4;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    for (int i = ix; i < nb; i += 4) {
+        device const uint8_t * q1 = x[i].qs + q_offset;
+        device const uint8_t * qh = x[i].qh + l0;
+        device const half * dh = &x[i].d;
+        device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
+
+        device const float * y2 = y1 + 128;
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+            yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+            yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+            yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+        }
+
+        for (int row = 0; row < 2; ++row) {
+            device const uint8_t * q2 = q1 + 64;
+
+            sc16[0] = a[0] & kmask1;
+            sc16[1] = a[2] & kmask1;
+            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
+            for (int l = 0; l < n; ++l) {
+                uint8_t h = qh[l];
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
+            }
+            const float dall = dh[0];
+            const float dmin = dh[1];
+            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += args.nb01;
+            qh += args.nb01;
+            dh += args.nb01/2;
+            a  += args.nb01/2;
+        }
+
+        y1 += 4 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = tot;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_q6_K_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int row = 2*r0 + sgitg;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
+
+    float sumf = 0;
+
+    const int tid  = tiisg/2;
+    const int ix   = tiisg%2;
+    const int ip   = tid/8;         // 0 or 1
+    const int il   = tid%8;
+    const int n    = 4;
+    const int l0   = n*il;
+    const int is   = 8*ip + l0/16;
+
+    const int y_offset = 128*ip + l0;
+    const int q_offset_l = 64*ip + l0;
+    const int q_offset_h = 32*ip + l0;
+
+    for (int i = ix; i < nb; i += 2) {
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
+        device const uint8_t * qh = x[i].qh + q_offset_h;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + y_offset;
+
+        const float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst_f32[row] = tot;
+    }
+}
+
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+// ======================= "True" 2-bit
+
+template<typename args_t>
+void kernel_mul_mv_iq2_xxs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
+    device const float         * y = (device const float         *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xxs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            device const uint8_t * aux8 = (device const uint8_t *)q2;
+            const uint32_t aux32 = q2[2] | (q2[3] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float sum = 0;
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
+                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * sum;
+
+            dh += args.nb01/2;
+            q2 += args.nb01/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq2_xs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 512);
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_xs * xr = x + ibl;
+        device const uint16_t * q2 = xr->qs + 4 * ib;
+        device const uint8_t  * sc = xr->scales + ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const uint8_t ls1 = sc[0] & 0xf;
+            const uint8_t ls2 = sc[0] >>  4;
+            const float d1 = db * (0.5f + ls1);
+            const float d2 = db * (0.5f + ls2);
+
+            float sum1 = 0, sum2 = 0;
+            for (int l = 0; l < 2; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+                const uint8_t signs = ssigns[(q2[l] >> 9)];
+                for (int j = 0; j < 8; ++j) {
+                    sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            for (int l = 2; l < 4; ++l) {
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+                const uint8_t signs = ssigns[(q2[l] >> 9)];
+                for (int j = 0; j < 8; ++j) {
+                    sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d1 * sum1 + d2 * sum2;
+
+            dh += args.nb01/2;
+            q2 += args.nb01/2;
+            sc += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq3_xxs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
+    device const float         * y = (device const float         *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
+    {
+        int nval = 4;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
+        nval = 2;
+        pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_xxs * xr = x + ibl;
+        device const uint8_t  * q3 = xr->qs + 8 * ib;
+        device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+            const float db = dh[0];
+            const uint32_t aux32 = gas[0] | (gas[1] << 16);
+            const float d = db * (0.5f + (aux32 >> 28));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
+                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh  += args.nb01/2;
+            q3  += args.nb01;
+            gas += args.nb01/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.5f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq3_s_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
+    {
+        int nval = 8;
+        int pos  = (32*sgitg + tiisg)*nval;
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq3_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 8 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + (ib/2);
+        device const uint8_t * signs = xr->signs + 4 * ib;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
+
+            float2 sum = {0};
+            for (int l = 0; l < 4; ++l) {
+                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
+                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+                    sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+                }
+            }
+            sumf[row] += d * (sum[0] + sum[1]);
+
+            dh    += args.nb01/2;
+            qs    += args.nb01;
+            qh    += args.nb01;
+            sc    += args.nb01;
+            signs += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq3_s_f32")]]
+kernel void kernel_mul_mv_iq3_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq2_s_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
+    //{
+    //    int nval = 32;
+    //    int pos  = (32*sgitg + tiisg)*nval;
+    //    for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq2_s * xr = x + ibl;
+        device const uint8_t * qs = xr->qs + 4 * ib;
+        device const uint8_t * qh = xr->qh + ib;
+        device const uint8_t * sc = xr->scales + ib;
+        device const uint8_t * signs = qs + QK_K/8;
+        device const half * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            const float db = dh[0];
+            const float d1 = db * (0.5f + (sc[0] & 0xf));
+            const float d2 = db * (0.5f + (sc[0] >>  4));
+
+            float2 sum = {0};
+            for (int l = 0; l < 2; ++l) {
+                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
+                    sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
+                }
+            }
+            sumf[row] += d1 * sum[0] + d2 * sum[1];
+
+            dh    += args.nb01/2;
+            qs    += args.nb01;
+            qh    += args.nb01;
+            sc    += args.nb01;
+            signs += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum * 0.25f;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_s_f32")]]
+kernel void kernel_mul_mv_iq2_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq1_s_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float sumy = 0;
+        for (int i = 0; i < 32; ++i) {
+            yl[i] = y4[i];
+            sumy += yl[i];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_s * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint16_t * qh = xr->qh + ib;
+        device const half     * dh = &xr->d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
+
+            float sum = 0;
+            for (int j = 0; j < 4; ++j) {
+                sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                     + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+                     + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                     + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
+
+            dh += args.nb01/2;
+            qs += args.nb01;
+            qh += args.nb01/2;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq1_m_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int nb32 = nb * (QK_K / 32);
+
+    const int ix = tiisg;
+
+    device const float * y4 = y + 32 * ix;
+
+    iq1m_scale_t scale;
+
+    for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+        float4 sumy = {0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
+        }
+
+        const int ibl = ib32 / (QK_K / 32);
+        const int ib  = ib32 % (QK_K / 32);
+
+        device const block_iq1_m * xr = x + ibl;
+        device const uint8_t  * qs = xr->qs + 4 * ib;
+        device const uint8_t  * qh = xr->qh + 2 * ib;
+        device const uint16_t * sc = (device const uint16_t *)xr->scales;
+
+        for (int row = 0; row < N_DST; row++) {
+            scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+
+            constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+            constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+            constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
+            constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
+
+            float2 sum = {0.f};
+            for (int j = 0; j < 4; ++j) {
+                sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+                        + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
+                sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
+                        + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
+            }
+            const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+            const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+
+            sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
+                                             (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
+
+            sc += args.nb01/2;
+            qs += args.nb01;
+            qh += args.nb01;
+        }
+
+        y4 += 32 * 32;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq4_nl_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK4_NL;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0 or 1
+
+    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK4_NL + it * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+
+            device const block_iq4_nl & xb = x[row*nb + ib];
+            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = q4[0] | (q4[1] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = q4[2] | (q4[3] << 16);
+            aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
+            aux32[0] &= 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 16 * QK4_NL;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+template<typename args_t>
+void kernel_mul_mv_iq4_xs_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
+
+    const int ix = tiisg/16;  // 0 or 1
+    const int it = tiisg%16;  // 0...15
+    const int ib = it/2;
+    const int il = it%2;
+
+    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[2]={0.f}, all_sum;
+
+    device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
+
+    uint32_t aux32[2];
+    thread const uint8_t * q8 = (thread const uint8_t *)aux32;
+
+    float4 qf1, qf2;
+
+    for (int ibl = ix; ibl < nb; ibl += 2) {
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+
+        for (int row = 0; row < 2; ++row) {
+            device const block_iq4_xs & xb = x[row*nb + ibl];
+            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
+
+            float4 acc1 = {0.f}, acc2 = {0.f};
+
+            aux32[0] = (q4[0]     ) & 0x0f0f0f0f;
+            aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[0] * qf1;
+            acc2 += yl[1] * qf2;
+
+            aux32[0] = (q4[1]     ) & 0x0f0f0f0f;
+            aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
+            acc1 += yl[2] * qf1;
+            acc2 += yl[3] * qf2;
+
+            acc1 += acc2;
+
+            const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
+            sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
+
+        }
+
+        yb += 2 * QK_K;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = all_sum;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+[[host_name("kernel_mul_mv_iq4_xs_f32")]]
+kernel void kernel_mul_mv_iq4_xs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows_q(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+        float4x4 temp;
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+template<typename T>
+kernel void kernel_get_rows_f(
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
+    }
+}
+
+kernel void kernel_get_rows_i32(
+        device const  void * src0,
+        device const  void * src1,
+        device     int32_t * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int64_t i10 = tgpig.x;
+    const int64_t i11 = tgpig.y;
+
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    const int64_t i02 = i11;
+
+    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+    }
+}
+
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(
+        constant ggml_metal_kargs_mul_mm & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup T     * sa = (threadgroup T     *)(shmem);
+    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+    const int r0 = tgpig.y;
+    const int r1 = tgpig.x;
+    const int im = tgpig.z;
+
+    // if this block is of 64x32 shape or smaller
+    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_T8x8     ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const int i12 = im%args.ne12;
+    const int i13 = im/args.ne12;
+
+    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short    offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0
+        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+
+    device const float   * y = (device const float   *)(src1
+        + args.nb13*i13
+        + args.nb12*i12
+        + args.nb11*(r1*BLOCK_SIZE_N + thread_col)
+        + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        T4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        #pragma unroll(16)
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+            +                     (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+            +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup const T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
+
+        #pragma unroll(4)
+        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
+            #pragma unroll(4)
+            for (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+            }
+
+            simdgroup_barrier(mem_flags::mem_none);
+
+            #pragma unroll(2)
+            for (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            }
+
+            #pragma unroll(8)
+            for (short i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+            }
+
+            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
+        }
+    }
+
+    if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
+        device float * C = (device float *) dst +
+            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *) shmem) \
+                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg == 0) {
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
+                }
+            }
+        }
+    }
+}
+
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
+// TODO: this kernel needs to be reimplemented from scratch for better performance
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        int32_t  ne00,
+        int32_t  ne02,
+        uint64_t nb01,
+        uint64_t nb02,
+        int32_t  ne11,
+        int32_t  ne12,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int32_t  ne0,
+        int32_t  ne1,
+        int64_t  ne0ne1,
+        device   const char * src0,
+        device   const char * src1,
+        threadgroup ushort2 * rowids,
+        device         char * dst,
+        threadgroup    char * shmem,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shmem);
+    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+    const int r0 = tgpig.y;
+    const int r1 = tgpig.x;
+
+    if (r1*BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 mc[8];
+    for (int i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+    short il = (tiitg % THREAD_PER_ROW);
+
+    ushort offset1 = il/nl;
+
+    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * id[1]
+        + nb11 * (id[0] % ne11)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        #pragma unroll(BLOCK_SIZE_K/8)
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            #pragma unroll(8)
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *) shmem) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (sgitg == 0) {
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+                int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
+
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + joff;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
+                }
+            }
+        }
+    }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm_id(
+        constant ggml_metal_kargs_mul_mm_id & args,
+        device const char * src0s,
+        device const char * src1,
+        device       char * dst,
+        device const char * ids,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int32_t i02 = tgpig.z;
+
+    tgpig.z = 0;
+
+    device const char * src0 = src0s + i02*args.nb02;
+
+    // row indices
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
+
+    // TODO: parallelize this loop
+    int32_t _ne1 = 0;
+    for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
+        for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
+            int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
+            if (id == i02) {
+                if (tiitg == 0) {
+                    rowids[_ne1] = ushort2(ii0, ii1);
+                }
+                _ne1++;
+            }
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        args.ne00,
+        args.ne02,
+        args.nb01,
+        args.nb02,
+        args.ne11,
+        args.ne12,
+        args.nb10,
+        args.nb11,
+        args.nb12,
+        args.ne0,
+        _ne1,
+        (int64_t)args.ne0*args.ne1,
+        src0,
+        src1,
+        rowids,
+        dst,
+        shmem,
+        tgpig,
+        tiitg,
+        sgitg);
+}
+
+#define QK_NL 16
+
+//
+// get rows
+//
+
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_0,    2,     dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,    2,     dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2,     dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2,     dequantize_q5_1>;
+template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_K,    QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q6_K,    QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs,  QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq3_s,   QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq2_s,   QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_s,   QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_m,   QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
+template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
+
+//
+// matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm) mat_mm_t;
+
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm;
+#endif
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm;
+
+//
+// indirect matrix-matrix multiplication
+//
+
+typedef decltype(kernel_mul_mm_id) mat_mm_id_t;
+
+template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+#endif
+template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id;
+
+//
+// matrix-vector multiplication
+//
+
+typedef void (kernel_mul_mv_impl_t)(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg);
+
+typedef void (kernel_mul_mv2_impl_t)(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg);
+
+template<kernel_mul_mv_impl_t impl_fn>
+void mmv_fn(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiitg,
+        ushort tiisg,
+        ushort sgitg) {
+    impl_fn(args, src0, src1, dst, tgpig, tiisg);
+}
+
+template<kernel_mul_mv2_impl_t impl_fn>
+void mmv_fn(
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiitg,
+        ushort tiisg,
+        ushort sgitg) {
+    impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(mmv_fn>) mul_mv_impl_fn_t;
+
+template<mul_mv_impl_fn_t impl_fn>
+kernel void kernel_mul_mv_id(
+        constant ggml_metal_kargs_mul_mv_id & args,
+        device const char * src0s,
+        device const char * src1,
+        device       char * dst,
+        device const char * ids,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int iid1 = tgpig.z/args.nei0;
+    const int idx  = tgpig.z%args.nei0;
+
+    tgpig.z = 0;
+
+    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];
+
+    const int64_t i11 = idx % args.ne11;
+    const int64_t i12 = iid1;
+
+    const int64_t i1 = idx;
+    const int64_t i2 = i12;
+
+    device const char * src0_cur = src0s + i02*args.nb02;
+    device const char * src1_cur = src1  + i11*args.nb11 + i12*args.nb12;
+
+    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
+
+    ggml_metal_kargs_mul_mv args0 = {
+        /*.ne00 =*/ args.ne00,
+        /*.ne01 =*/ args.ne01,
+        /*.ne02 =*/ 1, // args.ne02,
+        /*.nb00 =*/ args.nb00,
+        /*.nb01 =*/ args.nb01,
+        /*.nb02 =*/ args.nb02,
+        /*.nb03 =*/ args.nb02, // args.ne02 == 1
+        /*.ne10 =*/ args.ne10,
+        /*.ne11 =*/ 1, // args.ne11,
+        /*.ne12 =*/ 1, // args.ne12,
+        /*.nb10 =*/ args.nb10,
+        /*.nb11 =*/ args.nb11,
+        /*.nb12 =*/ args.nb12,
+        /*.nb13 =*/ args.nb12, // ne12 == 1
+        /*.ne0  =*/ args.ne0,
+        /*.ne1  =*/ 1, // args.ne1,
+        /*.r2   =*/ 1,
+        /*.r3   =*/ 1,
+    };
+
+    impl_fn(
+        args0,
+        /* src0 */ src0_cur,
+        /* src1 */ src1_cur,
+        /* dst  */ dst_cur,
+        shmem,
+        tgpig,
+        tiitg,
+        tiisg,
+        sgitg);
+}
+
+typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+#endif
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id>;
+
+kernel void kernel_pool_2d_max_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+
+    float res = -INFINITY;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            res = MAX(res, i_ptr[i * IW + j]);
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
+
+kernel void kernel_pool_2d_avg_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+    // const float scale = 1. / ((eh - bh) * (ew - bw));
+    const float scale = 1. / (k0 * k1);
+
+    float res = 0;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            float cur = i_ptr[i * IW + j];
+            res += cur * scale;
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
diff --git a/llama/ggml-metal_darwin_arm64.m b/llama/ggml-metal_darwin_arm64.m
new file mode 100644
index 000000000..d72129c3a
--- /dev/null
+++ b/llama/ggml-metal_darwin_arm64.m
@@ -0,0 +1,4943 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#import "ggml-metal.h"
+
+#import "ggml-impl.h"
+#import "ggml-backend-impl.h"
+#import "ggml-metal-impl.h"
+
+#import <Foundation/Foundation.h>
+
+#import <Metal/Metal.h>
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+// max memory buffers that can be mapped to the device
+#define GGML_METAL_MAX_BUFFERS 64
+
+// max number of MTLCommandBuffer used to submit a graph for processing
+#define GGML_METAL_MAX_COMMAND_BUFFERS 8
+
+#define UNUSED(x) (void)(x)
+
+// globals
+
+// overload of MTLGPUFamilyMetal3 (not available in some environments)
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+
+// initialized in ggml_backend_metal_reg
+static struct ggml_backend_reg    g_ggml_backend_metal_reg;
+static struct ggml_backend_device g_ggml_backend_metal_device;
+
+// information about a Metal device
+// note: assumes single GPU device - the default one
+// TODO: support multiple GPU devices
+static struct ggml_backend_metal_device_context {
+    id<MTLDevice> mtl_device;
+    int           mtl_device_ref_count;
+
+    bool has_simdgroup_reduction;
+    bool has_simdgroup_mm;
+    bool has_bfloat;
+    bool use_bfloat;
+
+    char name[128];
+} g_ggml_ctx_dev_main = {
+    /*.mtl_device              =*/ nil,
+    /*.mtl_device_ref_count    =*/ 0,
+    /*.has_simdgroup_reduction =*/ false,
+    /*.has_simdgroup_mm        =*/ false,
+    /*.has_bfloat              =*/ false,
+    /*.use_bfloat              =*/ false,
+    /*.name                    =*/ "",
+};
+
+// acquire
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+
+    if (ctx->mtl_device == nil) {
+        ctx->mtl_device = MTLCreateSystemDefaultDevice();
+
+        ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+
+        ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+        ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+        ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
+
+#if defined(GGML_METAL_USE_BF16)
+        ctx->use_bfloat = ctx->has_bfloat;
+#else
+        ctx->use_bfloat = false;
+#endif
+
+        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+    }
+
+    ctx->mtl_device_ref_count++;
+
+    return ctx->mtl_device;
+}
+
+// release
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+    assert(ctx->mtl_device_ref_count > 0);
+
+    ctx->mtl_device_ref_count--;
+
+    if (ctx->mtl_device_ref_count == 0) {
+        [ctx->mtl_device release];
+        ctx->mtl_device = nil;
+    }
+}
+
+// kernels
+
+struct ggml_metal_kernel {
+    id<MTLComputePipelineState> pipeline;
+};
+
+enum ggml_metal_kernel_type {
+    GGML_METAL_KERNEL_TYPE_ADD,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_SUB,
+    GGML_METAL_KERNEL_TYPE_SUB_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
+    GGML_METAL_KERNEL_TYPE_SCALE,
+    GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_CLAMP,
+    GGML_METAL_KERNEL_TYPE_TANH,
+    GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
+    GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_4,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
+    GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SILU_4,
+    GGML_METAL_KERNEL_TYPE_ELU,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+    GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
+    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+    GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
+    GGML_METAL_KERNEL_TYPE_UNPAD_F32,
+    GGML_METAL_KERNEL_TYPE_ARANGE_F32,
+    GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
+    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+    GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_SET_I32,
+    GGML_METAL_KERNEL_TYPE_SET_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CONCAT,
+    GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SQRT,
+    GGML_METAL_KERNEL_TYPE_SIN,
+    GGML_METAL_KERNEL_TYPE_COS,
+    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_ARGMAX,
+
+    GGML_METAL_KERNEL_TYPE_COUNT
+};
+
+struct ggml_backend_metal_context {
+    id<MTLCommandQueue> queue;
+
+    dispatch_queue_t d_queue;
+
+    struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
+
+    // capture state
+    bool capture_next_compute;
+    bool capture_started;
+
+    id<MTLCaptureScope> capture_scope;
+
+    // command buffer state
+    int n_cb;           // number of extra threads used to submit the command buffers
+    int n_nodes_0;      // number of nodes submitted by the main thread
+    int n_nodes_1;      // remaining number of nodes submitted by the n_cb threads
+    int n_nodes_per_cb;
+
+    struct ggml_cgraph * gf;
+
+    // the callback given to the thread pool
+    void (^encode_async)(size_t ith);
+
+    // n_cb command buffers + 1 used by the main thread
+    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+
+    // abort ggml_metal_graph_compute if callback returns true
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
+};
+
+// MSL code
+// TODO: move the contents here when ready
+//       for now it is easier to work in a separate file
+// static NSString * const msl_library_source = @"see metal.metal";
+
+// Here to assist with NSBundle Path Hack
+@interface GGMLMetalClass : NSObject
+@end
+@implementation GGMLMetalClass
+@end
+
+static void * ggml_metal_host_malloc(size_t n) {
+    void * data = NULL;
+
+#if TARGET_OS_OSX
+    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
+    if (err != KERN_SUCCESS) {
+        GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
+        return NULL;
+    }
+#else
+    const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
+    if (result != 0) {
+        GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
+        return NULL;
+    }
+#endif
+
+    return data;
+}
+
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
+    GGML_LOG_INFO("%s: allocating\n", __func__);
+
+#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
+    // Show all the Metal device instances in the system
+    NSArray * devices = MTLCopyAllDevices();
+    for (id<MTLDevice> device in devices) {
+        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
+    }
+    [devices release]; // since it was created by a *Copy* C method
+#endif
+
+    // init context
+    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
+
+    ctx->queue  = [device newCommandQueue];
+    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
+
+    id<MTLLibrary> metal_library;
+
+    // load library
+    //
+    // - first check if the library is embedded
+    // - then check if the library is in the bundle
+    // - if not found, load the source and compile it
+    // - if that fails, return NULL
+    {
+        NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+        bundle = SWIFTPM_MODULE_BUNDLE;
+#else
+        bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+#endif
+
+        NSError * error = nil;
+
+#if GGML_METAL_EMBED_LIBRARY
+        const bool try_metallib = false;
+#else
+        const bool try_metallib = true;
+#endif
+
+        NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
+        if (path_lib == nil) {
+            // Try to find the resource in the directory where the current binary located.
+            NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
+            NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
+            NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
+            if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+                GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
+                NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
+                if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
+                    // Optionally, if this is a symlink, try to resolve it.
+                    default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
+                    if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
+                        // It is a relative path, adding the binary directory as directory prefix.
+                        default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
+                    }
+                    if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+                        // Link to the resource could not be resolved.
+                        default_metallib_path = nil;
+                    } else {
+                        GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
+                    }
+                }
+            } else {
+                // The resource couldn't be found in the binary's directory.
+                default_metallib_path = nil;
+            }
+            path_lib = default_metallib_path;
+        }
+
+        if (try_metallib && path_lib != nil) {
+            // pre-compiled library found
+            NSURL * libURL = [NSURL fileURLWithPath:path_lib];
+            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
+
+            metal_library = [device newLibraryWithURL:libURL error:&error];
+            if (error) {
+                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
+        } else {
+#if GGML_METAL_EMBED_LIBRARY
+            GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
+
+            extern const char ggml_metallib_start[];
+            extern const char ggml_metallib_end[];
+
+            NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
+#else
+            GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+            NSString * path_source;
+            NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
+
+            GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
+
+            if (path_resource) {
+                path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
+            } else {
+                path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            }
+
+            if (path_source == nil) {
+                GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
+                path_source = @"ggml-metal.metal";
+            }
+
+            GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
+
+            NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
+            if (error) {
+                GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
+#endif // GGML_METAL_EMBED_LIBRARY
+
+            @autoreleasepool {
+                // dictionary of preprocessor macros
+                NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
+                if (ctx_dev->use_bfloat) {
+                    [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"];
+                }
+
+#if GGML_METAL_EMBED_LIBRARY
+                [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
+#endif
+
+                MTLCompileOptions * options = [MTLCompileOptions new];
+                options.preprocessorMacros = prep;
+
+                //[options setFastMathEnabled:false];
+
+                metal_library = [device newLibraryWithSource:src options:options error:&error];
+                if (error) {
+                    GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                    return NULL;
+                }
+
+#if !__has_feature(objc_arc)
+                [options release];
+#endif
+            }
+#if GGML_METAL_EMBED_LIBRARY
+            [src release];
+#endif // GGML_METAL_EMBED_LIBRARY
+        }
+    }
+
+    // print MTL GPU family:
+    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    {
+        for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d  (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+                break;
+            }
+        }
+
+        for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+                break;
+            }
+        }
+
+        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
+                break;
+            }
+        }
+    }
+
+    GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false");
+    GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
+    GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat                  ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
+
+    ctx->capture_next_compute = false;
+    ctx->capture_started = false;
+    ctx->capture_scope = nil;
+
+    ctx->gf = nil;
+    ctx->encode_async = nil;
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        ctx->command_buffers[i] = nil;
+    }
+
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
+    }
+#endif
+
+    // load kernels
+    {
+        NSError * error = nil;
+
+        for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
+            ctx->kernels[i].pipeline = nil;
+        }
+
+#define GGML_METAL_ADD_KERNEL(e, name, supported) \
+        if (supported) { \
+            struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+            id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
+            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                    (int) kernel->pipeline.threadExecutionWidth); \
+            [metal_function release]; \
+            if (error) { \
+                GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+                [metal_library release]; \
+                return NULL; \
+            } \
+        } else { \
+            GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
+        }
+
+        const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
+        const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
+        const bool use_bfloat              = ctx_dev->use_bfloat;
+
+        // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                       add_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB,                           sub,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW,                       sub_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                           mul,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                       mul_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                           div,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                       div_row,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                    repeat_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                    repeat_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                    repeat_i32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16,                    repeat_i16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                         scale,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                       sigmoid,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                        gelu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                    gelu_quick,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU,                           elu,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                 get_rows_bf16,                  use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                 get_rows_q5_1,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                 get_rows_q8_0,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                 get_rows_q2_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                 get_rows_q3_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                 get_rows_q4_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,                 get_rows_q5_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,                 get_rows_q6_K,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,              get_rows_iq2_xxs,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,               get_rows_iq2_xs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,              get_rows_iq3_xxs,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,                get_rows_iq3_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,                get_rows_iq2_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,                get_rows_iq1_s,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M,                get_rows_iq1_m,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,            mul_mv_bf16_f32_l4,             has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,              mul_mv_bf16_bf16,               has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,       mul_mv_ext_f16_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,       mul_mv_ext_f16_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,       mul_mv_ext_f16_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,       mul_mv_ext_f16_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,      mul_mv_ext_q4_0_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,      mul_mv_ext_q4_0_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,      mul_mv_ext_q4_0_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,      mul_mv_ext_q4_0_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,      mul_mv_ext_q4_1_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,      mul_mv_ext_q4_1_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,      mul_mv_ext_q4_1_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,      mul_mv_ext_q4_1_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,      mul_mv_ext_q5_0_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,      mul_mv_ext_q5_0_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,      mul_mv_ext_q5_0_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,      mul_mv_ext_q5_0_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,      mul_mv_ext_q5_1_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,      mul_mv_ext_q5_1_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,      mul_mv_ext_q5_1_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,      mul_mv_ext_q5_1_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,      mul_mv_ext_q8_0_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,      mul_mv_ext_q8_0_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,      mul_mv_ext_q8_0_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,      mul_mv_ext_q8_0_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,      mul_mv_ext_q4_K_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,      mul_mv_ext_q4_K_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,      mul_mv_ext_q4_K_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,      mul_mv_ext_q4_K_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,      mul_mv_ext_q5_K_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,      mul_mv_ext_q5_K_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,      mul_mv_ext_q5_K_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,      mul_mv_ext_q5_K_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,      mul_mv_ext_q6_K_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,      mul_mv_ext_q6_K_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,      mul_mv_ext_q6_K_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,      mul_mv_ext_q6_K_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,    mul_mv_ext_iq4_nl_f32_r1_2,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,    mul_mv_ext_iq4_nl_f32_r1_3,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,    mul_mv_ext_iq4_nl_f32_r1_4,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,    mul_mv_ext_iq4_nl_f32_r1_5,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              has_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         has_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           has_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,            mul_mv_id_bf16_f32,             has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,               mul_mm_bf16_f32,                has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,            mul_mm_id_bf16_f32,             has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,     conv_transpose_1d_f32_f32,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,     conv_transpose_1d_f16_f32,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UNPAD_F32,                     unpad_f32,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,      flash_attn_ext_q4_0_h112,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,      flash_attn_ext_q4_0_h128,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,      flash_attn_ext_q4_0_h256,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,       flash_attn_ext_q4_1_h64,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,       flash_attn_ext_q4_1_h80,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,       flash_attn_ext_q4_1_h96,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,      flash_attn_ext_q4_1_h112,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,      flash_attn_ext_q4_1_h128,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,      flash_attn_ext_q4_1_h256,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,       flash_attn_ext_q5_0_h64,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,       flash_attn_ext_q5_0_h80,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,       flash_attn_ext_q5_0_h96,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,      flash_attn_ext_q5_0_h112,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,      flash_attn_ext_q5_0_h128,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,      flash_attn_ext_q5_0_h256,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,       flash_attn_ext_q5_1_h64,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,       flash_attn_ext_q5_1_h80,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,       flash_attn_ext_q5_1_h96,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,      flash_attn_ext_q5_1_h112,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,      flash_attn_ext_q5_1_h128,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,      flash_attn_ext_q5_1_h256,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,       flash_attn_ext_q8_0_h64,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,       flash_attn_ext_q8_0_h80,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,       flash_attn_ext_q8_0_h96,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,      flash_attn_ext_q8_0_h112,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                       set_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                       set_i32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                  cpy_bf16_f32,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                 cpy_bf16_bf16,                  use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                        argmax,                         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true);
+    }
+
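+    // all compute pipeline states have been created at this point, so the compiled library can be released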
+    [metal_library release];
+
+    return ctx;
+}
+
+static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
+    GGML_LOG_INFO("%s: deallocating\n", __func__);
+
+    for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
+        [ctx->kernels[i].pipeline release];
+    }
+
+    Block_release(ctx->encode_async);
+
+    [ctx->queue release];
+
+    dispatch_release(ctx->d_queue);
+
+    free(ctx);
+}
+
+// temporarily defined here for compatibility between ggml-backend and the old API
+
+struct ggml_backend_metal_buffer {
+    void   * data;
+    size_t   size;
+
+    id<MTLBuffer> metal;
+};
+
+struct ggml_backend_metal_buffer_context {
+    void * all_data;
+    size_t all_size;
+    bool owned;
+
+    // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
+    int n_buffers;
+    struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+};
+
+// finds the Metal buffer that contains the tensor data on the GPU device
+// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
+// Metal buffer based on the host memory pointer
+//
+static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) {
+    //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+
+    const int64_t tsize = ggml_nbytes(t);
+
+    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
+
+    struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context;
+
+    // find the view that contains the tensor fully
+    for (int i = 0; i < buf_ctx->n_buffers; ++i) {
+        const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data;
+
+        //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size);
+        if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) {
+            *offs = (size_t) ioffs;
+
+            //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
+
+            return buf_ctx->buffers[i].metal;
+        }
+    }
+
+    GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
+
+    return nil;
+}
+
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
+    const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm;
+    const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction;
+    const bool use_bfloat              = ctx_dev->use_bfloat;
+
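+    // if the device cannot use bfloat, reject any op that has a BF16 source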
+    if (!use_bfloat) {
+        for (size_t i = 0, n = 3; i < n; ++i) {
+            if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+        }
+    }
+
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_ELU:
+                    return ggml_is_contiguous(op->src[0]);
+                default:
+                    return false;
+            }
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_CONCAT:
+        case GGML_OP_ADD:
+        case GGML_OP_SUB:
+        case GGML_OP_ACC:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
+        case GGML_OP_SCALE:
+        case GGML_OP_CLAMP:
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            return true;
+        case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_GROUP_NORM:
+            return has_simdgroup_reduction;
+        case GGML_OP_RMS_NORM:
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+        case GGML_OP_ARGMAX:
+        case GGML_OP_NORM:
+            return true;
+        case GGML_OP_ROPE:
+            {
+                const int mode = ((const int32_t *) op->op_params)[2];
+                if (mode & GGML_ROPE_TYPE_MROPE) {
+                    return false;
+                }
+                if (mode & GGML_ROPE_TYPE_VISION) {
+                    return false;
+                }
+                return true;
+            }
+        case GGML_OP_IM2COL:
+            return op->src[0]->type == GGML_TYPE_F16;
+        case GGML_OP_POOL_1D:
+            return false;
+        case GGML_OP_POOL_2D:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_PAD_REFLECT_1D:
+        case GGML_OP_UNPAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_LEAKY_RELU:
+            return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[1]->type != op->src[2]->type) {
+                return false;
+            }
+            return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
+            return true;
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
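+            // a non-F32 src0 (quantized or F16) may be paired with any src1 type; an F32 src0 requires an F32 src1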
+            return has_simdgroup_reduction &&
+                (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                           case GGML_TYPE_F32:
+                           case GGML_TYPE_F16:
+                           case GGML_TYPE_BF16:
+                           case GGML_TYPE_Q8_0:
+                           case GGML_TYPE_Q4_0:
+                           case GGML_TYPE_Q4_1:
+                           case GGML_TYPE_Q5_0:
+                           case GGML_TYPE_Q5_1:
+                           case GGML_TYPE_IQ4_NL:
+                                return true;
+                           default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_F16:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    case GGML_TYPE_BF16:
+                        switch (op->type) {
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_BF16:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
+        case GGML_OP_SET:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_I32:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_GET_ROWS:
+            {
+                return op->ne[3] == 1;
+            }
+        default:
+            return false;
+    }
+}
+
+static void ggml_metal_encode_node(
+                        ggml_backend_t   backend,
+                                   int   idx,
+          id<MTLComputeCommandEncoder>   encoder) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    struct ggml_cgraph * gf = ctx->gf;
+
+    struct ggml_tensor * node = ggml_graph_node(gf, idx);
+
+    //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
+
+    struct ggml_tensor * src0 = node->src[0];
+    struct ggml_tensor * src1 = node->src[1];
+    struct ggml_tensor * src2 = node->src[2];
+    struct ggml_tensor * dst  = node;
+
+    if (ggml_is_empty(dst)) {
+        return;
+    }
+
+    switch (dst->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_PERMUTE:
+            {
+                // noop -> next node
+            } return;
+        default:
+            {
+            } break;
+    }
+
+    if (!ggml_metal_supports_op(ctx_dev, dst)) {
+        GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+        GGML_ABORT("unsupported op");
+    }
+
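+    // gather the shapes (ne*) and byte strides (nb*) of the sources and destination for the kernel argument structs below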
+    const int64_t  ne00 = src0 ? src0->ne[0] : 0;
+    const int64_t  ne01 = src0 ? src0->ne[1] : 0;
+    const int64_t  ne02 = src0 ? src0->ne[2] : 0;
+    const int64_t  ne03 = src0 ? src0->ne[3] : 0;
+
+    const uint64_t nb00 = src0 ? src0->nb[0] : 0;
+    const uint64_t nb01 = src0 ? src0->nb[1] : 0;
+    const uint64_t nb02 = src0 ? src0->nb[2] : 0;
+    const uint64_t nb03 = src0 ? src0->nb[3] : 0;
+
+    const int64_t  ne10 = src1 ? src1->ne[0] : 0;
+    const int64_t  ne11 = src1 ? src1->ne[1] : 0;
+    const int64_t  ne12 = src1 ? src1->ne[2] : 0;
+    const int64_t  ne13 = src1 ? src1->ne[3] : 0;
+
+    const uint64_t nb10 = src1 ? src1->nb[0] : 0;
+    const uint64_t nb11 = src1 ? src1->nb[1] : 0;
+    const uint64_t nb12 = src1 ? src1->nb[2] : 0;
+    const uint64_t nb13 = src1 ? src1->nb[3] : 0;
+
+    const int64_t  ne20 = src2 ? src2->ne[0] : 0;
+    const int64_t  ne21 = src2 ? src2->ne[1] : 0;
+    const int64_t  ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22);
+    const int64_t  ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+
+    const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+    const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+    const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+    const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
+    const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
+    const int64_t  ne2  =  dst ?  dst->ne[2] : 0;
+    const int64_t  ne3  =  dst ?  dst->ne[3] : 0;
+
+    const uint64_t nb0  =  dst ?  dst->nb[0] : 0;
+    const uint64_t nb1  =  dst ?  dst->nb[1] : 0;
+    const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
+    const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
+
+    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
+
+    size_t offs_src0 = 0;
+    size_t offs_src1 = 0;
+    size_t offs_src2 = 0;
+    size_t offs_dst  = 0;
+
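+    // resolve the Metal buffer and byte offset backing each tensor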
+    id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
+    id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
+    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
+
+#if 0
+    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    if (src0) {
+        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+    }
+    if (src1) {
+        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+    }
+    if (dst) {
+        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+    }
+#endif
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
+
+    switch (dst->op) {
+        case GGML_OP_CONCAT:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+                const int32_t dim = ((const int32_t *) dst->op_params)[0];
+
+                ggml_metal_kargs_concat args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                    /*.dim  =*/ dim,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ADD:
+        case GGML_OP_SUB:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                const size_t offs = 0;
+
+                bool bcast_row = false;
+
+                id<MTLComputePipelineState> pipeline = nil;
+
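+                // fast path: when src1 is a single contiguous row and both ne00 and ne10 are divisible by 4,
+                // broadcast it across src0 using the *_ROW kernels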
+                if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    // src1 is a row
+                    GGML_ASSERT(ne11 == 1);
+
+                    switch (dst->op) {
+                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+
+                    bcast_row = true;
+                } else {
+                    switch (dst->op) {
+                        case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                        case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
+                        case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                        case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    }
+                }
+
+                ggml_metal_kargs_bin args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                    /*.offs =*/ offs,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                if (bcast_row) {
+                    const int64_t n = ggml_nelements(dst)/4;
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } else {
+                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                }
+            } break;
+        case GGML_OP_REPEAT:
+            {
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                    case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                    case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_repeat args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ACC:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+
+                const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
+                const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
+                const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
+                const size_t offs = ((const int32_t *) dst->op_params)[3];
+
+                const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    // run a separate kernel to cpy src->dst
+                    // not sure how to avoid this
+                    // TODO: make a simpler cpy_bytes kernel
+
+                    const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
+
+                    ggml_metal_kargs_cpy args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.ne03 =*/ ne03,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.ne2  =*/ ne2,
+                        /*.ne3  =*/ ne3,
+                        /*.nb0  =*/ nb0,
+                        /*.nb1  =*/ nb1,
+                        /*.nb2  =*/ nb2,
+                        /*.nb3  =*/ nb3,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                }
+
+                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+                ggml_metal_kargs_bin args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ pnb1,
+                    /*.nb02 =*/ pnb2,
+                    /*.nb03 =*/ pnb3,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ pnb1,
+                    /*.nb2  =*/ pnb2,
+                    /*.nb3  =*/ pnb3,
+                    /*.offs =*/ offs,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_SCALE:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                float scale;
+                memcpy(&scale, dst->op_params, sizeof(scale));
+
+                int64_t n = ggml_nelements(dst);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
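+                // use the 4-element-wide kernel when the element count is divisible by 4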
+                if (n % 4 == 0) {
+                    n /= 4;
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
+                }
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
+                [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_CLAMP:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+
+                float min;
+                float max;
+                memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&min   length:sizeof(min) atIndex:2];
+                [encoder setBytes:&max   length:sizeof(max) atIndex:3];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_UNARY:
+            // we are not taking into account the strides, so for now require contiguous tensors
+            GGML_ASSERT(ggml_is_contiguous(src0));
+
+            switch (ggml_get_unary_op(node)) {
+                case GGML_UNARY_OP_TANH:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_RELU:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_SIGMOID:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_GELU:
+                {
+                    int64_t n = ggml_nelements(dst);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                {
+                    int64_t n = ggml_nelements(dst);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_SILU:
+                {
+                    int64_t n = ggml_nelements(dst);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                case GGML_UNARY_OP_ELU:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+                default:
+                {
+                    GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
+                    GGML_ABORT("fatal error");
+                }
+            } break;
+        case GGML_OP_SQR:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SQRT:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SIN:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_COS:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SUM_ROWS:
+            {
+                GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SOFT_MAX:
+            {
+                GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+
+                int nth = 32; // SIMD width
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
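+                // grow the threadgroup width in powers of two based on the row length (counted in float4 units when ne00 is a multiple of 4)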
+                if (ne00%4 == 0) {
+                    while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
+                        nth *= 2;
+                    }
+                    if (use_f16) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
+                    }
+                } else {
+                    while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+                        nth *= 2;
+                    }
+                    if (use_f16) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
+                    }
+                }
+
+                float scale;
+                float max_bias;
+
+                memcpy(&scale,    ((const int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                const int64_t nrows_x = ggml_nrows(src0);
+                const int64_t nrows_y = src0->ne[1];
+
+                const uint32_t n_head      = nrows_x/nrows_y;
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
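+                // ALiBi slope bases: heads below n_head_log2 use powers of m0, the rest use odd powers of m1 (irrelevant when max_bias == 0)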
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                // TODO: add ggml_metal_kargs struct
+                // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                if (id_src1) {
+                    [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
+                }
+                [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
+                [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
+                [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
+                [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
+                [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
+                [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
+                [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
+                [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
+                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_DIAG_MASK_INF:
+            {
+                const int n_past = ((const int32_t *)(dst->op_params))[0];
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (ne00%8 == 0) {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
+                } else {
+                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
+                }
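+                // the _8 variant masks eight values per thread, hence the grid below is divided by 8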
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+
+                if (ne00%8 == 0) {
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
+                else {
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
+            } break;
+        case GGML_OP_SSM_CONV:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
+                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
+                [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
+                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
+                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
+                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
+                [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
+                [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
+                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
+                [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_SSM_SCAN:
+            {
+                struct ggml_tensor * src3 = node->src[3];
+                struct ggml_tensor * src4 = node->src[4];
+                struct ggml_tensor * src5 = node->src[5];
+
+                GGML_ASSERT(src3);
+                GGML_ASSERT(src4);
+                GGML_ASSERT(src5);
+
+                size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
+                size_t offs_src5 = 0;
+
+                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
+                id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+
+                const int64_t  ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                const int64_t  ne31 = src3->ne[1]; GGML_UNUSED(ne31);
+
+                const uint64_t nb30 = src3->nb[0];
+                const uint64_t nb31 = src3->nb[1];
+
+                const int64_t  ne40 = src4->ne[0]; GGML_UNUSED(ne40);
+                const int64_t  ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                const int64_t  ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+
+                const uint64_t nb40 = src4->nb[0];
+                const uint64_t nb41 = src4->nb[1];
+                const uint64_t nb42 = src4->nb[2];
+
+                const int64_t  ne50 = src5->ne[0]; GGML_UNUSED(ne50);
+                const int64_t  ne51 = src5->ne[1]; GGML_UNUSED(ne51);
+                const int64_t  ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+
+                const uint64_t nb50 = src5->nb[0];
+                const uint64_t nb51 = src5->nb[1];
+                const uint64_t nb52 = src5->nb[2];
+
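+                // selective-scan (Mamba) dimensions: state size per channel, number of inner channels, tokens per sequence, and number of sequences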
+                const int64_t d_state      = ne00;
+                const int64_t d_inner      = ne01;
+                const int64_t n_seq_tokens = ne11;
+                const int64_t n_seqs       = ne02;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+
+                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
+                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
+                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
+                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
+
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
+                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
+                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
+                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
+                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
+                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
+                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
+                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
+                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_MUL_MAT:
+            {
+                GGML_ASSERT(ne00 == ne10);
+
+                GGML_ASSERT(ne12 % ne02 == 0);
+                GGML_ASSERT(ne13 % ne03 == 0);
+
+                const uint32_t r2 = ne12/ne02;
+                const uint32_t r3 = ne13/ne03;
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                const int ne11_mm_min = 4;
+
+                // first try to use small-batch mat-mv kernels
+                // these should be efficient for BS [2, ~8]
+                if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+                    (
+                     (
+                      (
+                       src0t == GGML_TYPE_F16  || // TODO: helper function
+                       src0t == GGML_TYPE_Q4_0 ||
+                       src0t == GGML_TYPE_Q4_1 ||
+                       src0t == GGML_TYPE_Q5_0 ||
+                       src0t == GGML_TYPE_Q5_1 ||
+                       src0t == GGML_TYPE_Q8_0 ||
+                       src0t == GGML_TYPE_IQ4_NL ||
+                       false) && (ne11 >= 2 && ne11 <= 8)
+                     ) ||
+                     (
+                      (
+                       src0t == GGML_TYPE_Q4_K ||
+                       src0t == GGML_TYPE_Q5_K ||
+                       src0t == GGML_TYPE_Q6_K ||
+                       false) && (ne11 >= 4 && ne11 <= 8)
+                     )
+                    )
+                   ) {
+                    // TODO: determine the optimal parameters based on grid utilization
+                    //       I still don't know why we should not always use the maximum available threads:
+                    //
+                    //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+                    //
+                    //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+                    //       values and there can be some tail effects when nsg is high. need to confirm this
+                    //
+                    const int nsg    = 2;                 // num simdgroups per threadgroup
+                    const int nxpsg  = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+                    const int nypsg  = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+                    const int r0ptg  = nypsg*nsg;         // num src0 rows per threadgroup
+                          int r1ptg  = 4;                 // num src1 rows per threadgroup
+
+                    // note: not sure how optimal those are across all different hardware. there might be something cleverer
+                    switch (ne11) {
+                        case 2:
+                            r1ptg = 2; break;
+                        case 3:
+                        case 6:
+                            r1ptg = 3; break;
+                        case 4:
+                        case 7:
+                        case 8:
+                            r1ptg = 4; break;
+                        case 5:
+                            r1ptg = 5; break;
+                    };
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F16:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q4_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q5_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q6_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_IQ4_NL:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        default: GGML_ABORT("not implemented");
+                    }
+
+                    ggml_metal_kargs_mul_mv_ext args = {
+                        /*.ne00  =*/ ne00,
+                        /*.ne01  =*/ ne01,
+                        /*.ne02  =*/ ne02,
+                        /*.nb00  =*/ nb00,
+                        /*.nb01  =*/ nb01,
+                        /*.nb02  =*/ nb02,
+                        /*.nb03  =*/ nb03,
+                        /*.ne10  =*/ ne10,
+                        /*.ne11  =*/ ne11,
+                        /*.ne12  =*/ ne12,
+                        /*.nb10  =*/ nb10,
+                        /*.nb11  =*/ nb11,
+                        /*.nb12  =*/ nb12,
+                        /*.nb13  =*/ nb13,
+                        /*.ne0   =*/ ne0,
+                        /*.ne1   =*/ ne1,
+                        /*.r2    =*/ r2,
+                        /*.r3    =*/ r3,
+                        /*.nsg   =*/ nsg,
+                        /*.nxpsg =*/ nxpsg,
+                        /*.r1ptg =*/ r1ptg,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
+                    //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else
+                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
+                        !ggml_is_transposed(src0) &&
+                        !ggml_is_transposed(src1) &&
+                        src1t == GGML_TYPE_F32 &&
+                        ne00 % 32 == 0 && ne00 >= 64 &&
+                        (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
+                    //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                    // some Metal matrix data types require aligned pointers
+                    // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
+                        case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
+                        case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
+                        default: break;
+                    }
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break;
+                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break;
+                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
+                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
+                        default: GGML_ABORT("MUL MAT-MAT not implemented");
+                    }
+
+                    ggml_metal_kargs_mul_mm args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne02 =*/ ne02,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne12 =*/ ne12,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.nb13 =*/ nb13,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.r2   =*/ r2,
+                        /*.r3   =*/ r3,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
+                    [encoder setThreadgroupMemoryLength:8192 atIndex:0];
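+                    // each threadgroup computes a 64x32 tile of dst (hence the /64 and /32 grid divisors below), staged through the 8192 bytes of threadgroup memory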
+                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                } else {
+                    int nth0 = 32;
+                    int nth1 = 1;
+                    int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    // use custom matrix x vector kernel
+                    switch (src0t) {
+                        case GGML_TYPE_F32:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
+                                nrows = 4;
+                            } break;
+                        case GGML_TYPE_F16:
+                            {
+                                nth0 = 32;
+                                nth1 = 1;
+                                if (src1t == GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
+                                        nrows = ne11;
+                                    } else {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
+                                        nrows = 4;
+                                    }
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
+                                    nrows = 4;
+                                }
+                            } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                nth0 = 32;
+                                nth1 = 1;
+                                if (src1t == GGML_TYPE_F32) {
+                                    if (ne11 * ne12 < 4) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
+                                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
+                                        nrows = ne11;
+                                    } else {
+                                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
+                                        nrows = 4;
+                                    }
+                                } else {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
+                                    nrows = 4;
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q2_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q3_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_K:
+                            {
+                                nth0 = 4; //1;
+                                nth1 = 8; //32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q6_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_M:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_NL:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
+                            } break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                GGML_ABORT("not implemented");
+                            }
+                    };
+
+                    ggml_metal_kargs_mul_mv args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne10 =*/ ne10,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.nb13 =*/ nb13,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.r2   =*/ r2,
+                        /*.r3   =*/ r3,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+
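+                    // the grid divisors below (8, 4, or 2 src0 rows per threadgroup, or nrows for the float paths) match how many rows each kernel variant processes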
+                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                        src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+                        const int mem_size = 32*sizeof(float);
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q3_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    } else {
+                        const int64_t ny = (ne11 + nrows - 1)/nrows;
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                }
+            } break;
+        case GGML_OP_MUL_MAT_ID:
+            {
+                const int n_as = src0->ne[2];
+
+                // src2 = ids
+                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
+
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
+
+                GGML_ASSERT(!ggml_is_transposed(src0));
+                GGML_ASSERT(!ggml_is_transposed(src1));
+
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                GGML_ASSERT(ne03 == 1);
+                GGML_ASSERT(ne13 == 1);
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                // ne20 = n_used_experts
+                // ne21 = n_rows
+                const int dst_rows = ne20*ne21;
+                const int dst_rows_min = n_as;
+                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
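+                // dst_rows_max is the threadgroup memory left over after the 8192-byte mat-mat tile (and a small reserve), at 4 bytes (ushort2) per row id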
+
+                // max size of the rowids array in the kernel shared buffer
+                GGML_ASSERT(dst_rows <= dst_rows_max);
+
+                // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+                // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
+                // !!!
+                // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+                //       indirect matrix multiplication
+                // !!!
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
+                        ne00 % 32 == 0 && ne00 >= 64 &&
+                        dst_rows > dst_rows_min) {
+                    // some Metal matrix data types require aligned pointers
+                    // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break;
+                        case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break;
+                        case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break;
+                        default: break;
+                    }
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
+                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
+                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
+                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
+                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
+                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
+                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
+                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
+                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
+                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
+                        default: GGML_ABORT("MUL_MAT_ID not implemented");
+                    }
+
+                    ggml_metal_kargs_mul_mm_id args = {
+                        /*.nei0 =*/ ne20,
+                        /*.nei1 =*/ ne21,
+                        /*.nbi1 =*/ nb21,
+                        /*.ne00 =*/ ne00,
+                        /*.ne02 =*/ ne02,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.ne13 =*/ ne13,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:4];
+
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                } else {
+                    int nth0 = 32;
+                    int nth1 = 1;
+                    int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    // use custom matrix x vector kernel
+                    switch (src0t) {
+                        case GGML_TYPE_F32:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_F16:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nth0 = 32;
+                                nth1 = 1;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nth0 = 32;
+                                nth1 = 1;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                nth0 = 8;
+                                nth1 = 8;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q2_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q3_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q4_K:
+                            {
+                                nth0 = 4; //1;
+                                nth1 = 8; //32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q5_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_Q6_K:
+                            {
+                                nth0 = 2;
+                                nth1 = 32;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_XXS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ3_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ2_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_S:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ1_M:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_NL:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
+                            } break;
+                        case GGML_TYPE_IQ4_XS:
+                            {
+                                nth0 = 4;
+                                nth1 = 16;
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
+                            } break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+                                GGML_ABORT("not implemented");
+                            }
+                    };
+
+                    if (ggml_is_quantized(src0t)) {
+                        GGML_ASSERT(ne00 >= nth0*nth1);
+                    }
+
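+                    // argument block for the indirect matrix-vector multiply (ggml_metal_kargs_mul_mv_id)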
+                    ggml_metal_kargs_mul_mv_id args = {
+                        /*.nei0 =*/ ne20,
+                        /*.nei1 =*/ ne21,
+                        /*.nbi1 =*/ nb21,
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.ne10 =*/ ne10,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.ne13 =*/ ne13,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.nb1  =*/ nb1,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
+
+                    const int64_t _ne1 = 1;
+                    const int tgz = dst_rows;
+
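+                    // dispatch geometry varies with the quantization type to match each kernel's row tiling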
+                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
+                            src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                            src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
+                        const int mem_size = 32*sizeof(float);
+                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q4_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q3_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q5_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                    else if (src0t == GGML_TYPE_Q6_K) {
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    } else {
+                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    }
+                }
+            } break;
+        case GGML_OP_GET_ROWS:
+            {
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32    ].pipeline; break;
+                    case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16    ].pipeline; break;
+                    case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16   ].pipeline; break;
+                    case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0   ].pipeline; break;
+                    case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1   ].pipeline; break;
+                    case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
+                    case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
+                    case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
+                    case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
+                    case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
+                    case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
+                    case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K   ].pipeline; break;
+                    case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
+                    case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                    case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break;
+                    case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S  ].pipeline; break;
+                    case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S  ].pipeline; break;
+                    case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S  ].pipeline; break;
+                    case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M  ].pipeline; break;
+                    case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
+                    case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
+                    case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                }
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+                [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+                [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+
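+                // one threadgroup per (row index, batch row) pair; 32 threads cooperate on each looked-up row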
+                [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+            } break;
+        case GGML_OP_RMS_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
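+                // pick a power-of-two threadgroup size: start at one SIMD group and grow up to the row width (in float4 units) or the pipeline limit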
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_GROUP_NORM:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+
+                float eps;
+                memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+                const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
+
+                int nth = 32; // SIMD width
+
+                //while (nth < ne00/4 && nth < 1024) {
+                //    nth *= 2;
+                //}
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
+                [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
+                [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
+                [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
+                [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+                [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_NORM:
+            {
+                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                float eps;
+                memcpy(&eps, dst->op_params, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ROPE:
+            {
+                // make sure we have one or more position id(ne10) per token(ne02)
+                GGML_ASSERT(ne10 % ne02 == 0);
+                GGML_ASSERT(ne10 >= ne02);
+
+                const int nth = MIN(1024, ne00);
+
+                const int n_past     = ((const int32_t *) dst->op_params)[0];
+                const int n_dims     = ((const int32_t *) dst->op_params)[1];
+                const int mode       = ((const int32_t *) dst->op_params)[2];
+                // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+                const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
+
+                float freq_base;
+                float freq_scale;
+                float ext_factor;
+                float attn_factor;
+                float beta_fast;
+                float beta_slow;
+
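+                // the float parameters are stored bitwise inside the int32 op_params array, hence the memcpy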
+                memcpy(&freq_base,   (const int32_t *) dst->op_params +  5, sizeof(float));
+                memcpy(&freq_scale,  (const int32_t *) dst->op_params +  6, sizeof(float));
+                memcpy(&ext_factor,  (const int32_t *) dst->op_params +  7, sizeof(float));
+                memcpy(&attn_factor, (const int32_t *) dst->op_params +  8, sizeof(float));
+                memcpy(&beta_fast,   (const int32_t *) dst->op_params +  9, sizeof(float));
+                memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float));
+
+                const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                if (!is_neox) {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                }
+
+                ggml_metal_kargs_rope args = {
+                    /*.ne00        =*/ ne00,
+                    /*.ne01        =*/ ne01,
+                    /*.ne02        =*/ ne02,
+                    /*.ne03        =*/ ne03,
+                    /*.nb00        =*/ nb00,
+                    /*.nb01        =*/ nb01,
+                    /*.nb02        =*/ nb02,
+                    /*.nb03        =*/ nb03,
+                    /*.ne0         =*/ ne0,
+                    /*.ne1         =*/ ne1,
+                    /*.ne2         =*/ ne2,
+                    /*.ne3         =*/ ne3,
+                    /*.nb0         =*/ nb0,
+                    /*.nb1         =*/ nb1,
+                    /*.nb2         =*/ nb2,
+                    /*.nb3         =*/ nb3,
+                    /*.n_past      =*/ n_past,
+                    /*.n_dims      =*/ n_dims,
+                    /*.n_ctx_orig  =*/ n_ctx_orig,
+                    /*.freq_base   =*/ freq_base,
+                    /*.freq_scale  =*/ freq_scale,
+                    /*.ext_factor  =*/ ext_factor,
+                    /*.attn_factor =*/ attn_factor,
+                    /*.beta_fast   =*/ beta_fast,
+                    /*.beta_slow   =*/ beta_slow,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
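+                // src2 optionally holds RoPE frequency factors; when absent, rebind src0 so buffer slot 3 is never left unset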
+                if (id_src2 != nil) {
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+                }
+                [encoder setBuffer:id_dst  offset:offs_dst      atIndex:4];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_IM2COL:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+                GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+                const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+                const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
+                const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
+                const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
+                const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
+
+                const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
+
+                const int32_t N  = src1->ne[is_2D ? 3 : 2];
+                const int32_t IC = src1->ne[is_2D ? 2 : 1];
+                const int32_t IH = is_2D ? src1->ne[1] : 1;
+                const int32_t IW =         src1->ne[0];
+
+                const int32_t KH = is_2D ? src0->ne[1] : 1;
+                const int32_t KW =         src0->ne[0];
+
+                const int32_t OH = is_2D ? dst->ne[2] : 1;
+                const int32_t OW =         dst->ne[1];
+
+                const int32_t CHW = IC * KH * KW;
+
+                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
+
+                const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
+
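+                // if a full (N, KH, KW) threadgroup would exceed the pipeline limit, fall back to the EXT kernels, which spread the work over additional threadgroups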
+                switch (dst->type) {
+                    case GGML_TYPE_F32: {
+                        pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
+                    } break;
+                    default: GGML_ABORT("fatal error");
+                };
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
+                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
+                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
+                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
+                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
+                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
+                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
+                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
+                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
+                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
+                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
+                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
+
+                if (is_gt_mttpt) {
+                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
+                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
+                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
+
+                    const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
+
+                    const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+                } else {
+                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                }
+            } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+                GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+
+                const int32_t IC = src1->ne[1];
+                const int32_t IL = src1->ne[0];
+
+                const int32_t K  = src0->ne[0];
+
+                const int32_t OL = dst->ne[0];
+                const int32_t OC = dst->ne[1];
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F32: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
+                    } break;
+                    default: GGML_ABORT("fatal error");
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0         atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1         atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst          atIndex:2];
+                [encoder setBytes:&IC      length:sizeof( int32_t)  atIndex:3];
+                [encoder setBytes:&IL      length:sizeof( int32_t)  atIndex:4];
+                [encoder setBytes:&K       length:sizeof( int32_t)  atIndex:5];
+                [encoder setBytes:&s0      length:sizeof( int32_t)  atIndex:6];
+                [encoder setBytes:&nb0     length:sizeof(uint64_t)  atIndex:7];
+                [encoder setBytes:&nb1     length:sizeof(uint64_t)  atIndex:8];
+
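+                // one single-thread threadgroup per output element (OL x OC)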
+                [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_UPSCALE:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                const float sf0 = (float)ne0/src0->ne[0];
+                const float sf1 = (float)ne1/src0->ne[1];
+                const float sf2 = (float)ne2/src0->ne[2];
+                const float sf3 = (float)ne3/src0->ne[3];
+
+                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
+                [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
+                [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
+                [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_PAD:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_PAD_REFLECT_1D:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
+                const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
+                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
+                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_UNPAD:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UNPAD_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
+                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
+                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
+                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
+                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
+                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ARANGE:
+            {
+                GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+                float start;
+                float step;
+
+                memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&step,  ((const int32_t *) dst->op_params) + 2, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
+                [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
+                [encoder setBytes:&start length:sizeof(start) atIndex:2];
+                [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
+
+                const int nth = MIN(1024, ne0);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                const int dim        = dst->op_params[0];
+                const int max_period = dst->op_params[1];
+
+                const int half = dim / 2;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
+                [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
+                [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+
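+                // one threadgroup per position; threads cover the dim/2 frequencies (each frequency yields a cos/sin pair)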
+                const int nth = MIN(1024, half);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_ARGSORT:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+                const int nrows = ggml_nrows(src0);
+
+                enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+                // bitonic sort requires the number of elements to be a power of 2
+                int64_t ne00_padded = 1;
+                while (ne00_padded < ne00) {
+                    ne00_padded *= 2;
+                }
+
+                // Metal kernels require the buffer size to be multiple of 16 bytes
+                // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+                const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (order) {
+                    case GGML_SORT_ORDER_ASC:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline;  break;
+                    case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                };
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
+            } break;
+        case GGML_OP_LEAKY_RELU:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+                float slope;
+                memcpy(&slope, dst->op_params, sizeof(float));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
+                [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
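+                // element-wise op: one single-thread threadgroup per output element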
+                const int64_t n = ggml_nelements(dst);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                GGML_ASSERT(ne00 % 4  == 0);
+                GGML_ASSERT(ne11 % 32 == 0);
+
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == src2->type);
+
+                GGML_ASSERT(ggml_are_same_shape (src1, src2));
+
+                struct ggml_tensor * src3 = node->src[3];
+
+                size_t offs_src3 = 0;
+
+                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+                GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+                GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                        "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+                const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+                //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+                const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+                const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+                const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+                const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+                const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+                float scale;
+                float max_bias;
+                float logit_softcap;
+                memcpy(&scale,         ((const int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias,      ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
+                memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    scale /= logit_softcap;
+                }
+
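+                // ALiBi slope bases: heads below n_head_log2 use powers of m0, the remaining heads use powers of m1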
+                const uint32_t n_head      = src0->ne[2];
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                bool use_vec_kernel = false;
+
+                // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
+                //       for now avoiding mainly to keep the number of templates/kernels a bit lower
+                if (ne01 >= 4 || (ne00%128 != 0)) {
+                    switch (src1->type) {
+                        case GGML_TYPE_F16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                GGML_ABORT("add template specialization for this type");
+                            }
+                    }
+                } else {
+                    use_vec_kernel = true;
+
+                    switch (ne00) {
+                        case 128:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
+                        case 256:
+                            {
+                                switch (src1->type) {
+                                    case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
+                                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
+                                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
+                                    case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
+                                    case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
+                                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
+                                    default:
+                                        {
+                                            GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                            GGML_LOG_ERROR("add template specialization for this type\n");
+                                            GGML_ABORT("add template specialization for this type");
+                                        }
+                                }
+                            } break;
+                        default:
+                                  {
+                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                      GGML_LOG_ERROR("add template specialization for this size\n");
+                                      GGML_ABORT("add template specialization for this size");
+                                  }
+                    }
+                }
+
+                ggml_metal_kargs_flash_attn_ext args = {
+                    /*.ne01          =*/ ne01,
+                    /*.ne02          =*/ ne02,
+                    /*.ne03          =*/ ne03,
+                    /*.nb01          =*/ nb01,
+                    /*.nb02          =*/ nb02,
+                    /*.nb03          =*/ nb03,
+                    /*.ne11          =*/ ne11,
+                    /*.ne_12_2       =*/ ne12,
+                    /*.ne_12_3       =*/ ne13,
+                    /*.nb_12_1       =*/ nb11,
+                    /*.nb_12_2       =*/ nb12,
+                    /*.nb_12_3       =*/ nb13,
+                    /*.nb31          =*/ nb31,
+                    /*.ne1           =*/ ne1,
+                    /*.ne2           =*/ ne2,
+                    /*.scale         =*/ scale,
+                    /*.max_bias      =*/ max_bias,
+                    /*.m0            =*/ m0,
+                    /*.m1            =*/ m1,
+                    /*.n_head_log2   =*/ n_head_log2,
+                    /*.logit_softcap =*/ logit_softcap,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2     atIndex:3];
+                if (id_src3) {
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:5];
+
+                if (!use_vec_kernel) {
+                    // half8x8 kernel
+                    const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                    GGML_ASSERT(nqptg <= 32);
+                    GGML_ASSERT(nqptg  % 8  == 0);
+                    GGML_ASSERT(ncpsg  % 32 == 0);
+
+                    // 2*(2*ncpsg + nqptg)*(nsg)
+                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+                    //
+                    // 16*32*(nsg)
+                    // the shared memory needed for the simdgroups to load the KV cache
+                    // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
+                    //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
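+                    // Illustrative example (editor's note, not from the upstream source): with
+                    // head size ne00 = 128 and nsg = 4 simdgroups,
+                    //   FATTN_SMEM(4) = (8*(128 + 2*(2*32 + 8)*4) + 16*32*4)*2 = (5632 + 2048)*2 = 15360 bytes,
+                    // i.e. 15 KiB of threadgroup memory; with the 32 KiB limit typical of Apple
+                    // GPUs the doubling search below stops at nsgmax = 8 (FATTN_SMEM(8) = 28672)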
+
+                    int64_t nsgmax = 2;
+
+                    while (true) {
+                        const size_t smem = FATTN_SMEM(nsgmax);
+                        if (smem > device.maxThreadgroupMemoryLength) {
+                            break;
+                        }
+                        nsgmax *= 2;
+                    }
+                    nsgmax /= 2;
+
+                    // simdgroups per threadgroup (a.k.a. warps)
+                    const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+
+                    const size_t smem = FATTN_SMEM(nsg);
+
+                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else {
+                    // half4x4 kernel
+                    const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
+                    const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+                    GGML_ASSERT(nqptg <= 32);
+                    GGML_ASSERT(nqptg  % 1  == 0);
+                    GGML_ASSERT(ncpsg  % 32 == 0);
+
+                    // ne00 + 2*ncpsg*(nsg)
+                    // for each query, we load it as f16 in shared memory (ne00)
+                    // and store the soft_max values and the mask
+                    //
+                    // ne00*(nsg)
+                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
+                    //
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
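+                    // Illustrative example (editor's note, not from the upstream source): with
+                    // ne00 = 128 and nsg = 4, FATTN_SMEM(4) = (1*(128 + 2*32*4) + 128*4)*2 = 1792 bytes,
+                    // far less than the block kernel above; this is what makes the vector kernel
+                    // attractive when only a few queries are processed per call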
+
+                    int64_t nsgmax = 2;
+
+                    while (true) {
+                        const size_t smem = FATTN_SMEM(nsgmax);
+                        if (smem > device.maxThreadgroupMemoryLength) {
+                            break;
+                        }
+                        nsgmax *= 2;
+                    }
+                    nsgmax /= 2;
+
+                    // simdgroups per threadgroup (a.k.a. warps)
+                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+
+                    int64_t nsg = 1;
+                    while (nsg <= nsgt) {
+                        nsg *= 2;
+                    }
+                    nsg /= 2;
+
+                    const size_t smem = FATTN_SMEM(nsg);
+
+                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+#undef FATTN_SMEM
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                }
+            } break;
+        case GGML_OP_DUP:
+        case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            {
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32:
+                        {
+                            GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
+
+                            switch (dstt) {
+                                case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                                case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
+                                case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
+                                case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+                                case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+                                case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+                                case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+                                case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
+                                case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_F16:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                                case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_BF16:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
+                                case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
+                                default: GGML_ASSERT(false && "not implemented");
+                            };
+                        } break;
+                    default: GGML_ABORT("not implemented");
+                }
+
+                ggml_metal_kargs_cpy args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_SET:
+            {
+                GGML_ASSERT(ggml_are_same_shape(src0, dst));
+                GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+                // src0 and dst as viewed during set
+                const size_t dst_nb0 = ggml_element_size(src0);
+
+                const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
+                const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
+                const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
+                const size_t offset  = ((int32_t *) dst->op_params)[3];
+                const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    memcpy(((char *)  dst->data), ((char *) src0->data), ggml_nbytes(dst));
+                }
+
+                const int im0 = (ne10 == 0 ? 0 : ne10-1);
+                const int im1 = (ne11 == 0 ? 0 : ne11-1);
+                const int im2 = (ne12 == 0 ? 0 : ne12-1);
+                const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+                GGML_ASSERT(offset + im0*dst_nb0  + im1*dst_nb1  + im2*dst_nb2  + im3*dst_nb3  <= ggml_nbytes(dst));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32:
+                        GGML_ASSERT(nb10 == sizeof(float));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
+                    case GGML_TYPE_I32:
+                        GGML_ASSERT(nb10 == sizeof(int32_t));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_set args = {
+                    /*.ne10    =*/ ne10,
+                    /*.ne11    =*/ ne11,
+                    /*.ne12    =*/ ne12,
+                    /*.nb10    =*/ nb10,
+                    /*.nb11    =*/ nb11,
+                    /*.nb12    =*/ nb12,
+                    /*.nb13    =*/ nb13,
+                    /*.nb1     =*/ dst_nb1,
+                    /*.nb2     =*/ dst_nb2,
+                    /*.nb3     =*/ dst_nb3,
+                    /*.offs    =*/ offset,
+                    /*.inplace =*/ inplace,
+                };
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        case GGML_OP_POOL_2D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
+
+                const int32_t * opts = dst->op_params;
+                enum ggml_op_pool op = opts[0];
+
+                id<MTLComputePipelineState> pipeline = nil;
+                switch (src0t) {
+                    case GGML_TYPE_F32: {
+                        switch(op) {
+                            case GGML_OP_POOL_AVG:
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
+                            case GGML_OP_POOL_MAX:
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        }
+                    } break;
+                    default: GGML_ASSERT(false && "not implemented");
+                }
+
+                const int32_t k0 = opts[1];
+                const int32_t k1 = opts[2];
+                const int32_t s0 = opts[3];
+                const int32_t s1 = opts[4];
+                const int32_t p0 = opts[5];
+                const int32_t p1 = opts[6];
+
+                const int64_t IH = src0->ne[1];
+                const int64_t IW = src0->ne[0];
+
+                const int64_t N  = dst->ne[3];
+                const int64_t OC = dst->ne[2];
+                const int64_t OH = dst->ne[1];
+                const int64_t OW = dst->ne[0];
+
+                const int64_t parallel_elements = N * OC * OH * OW;
+                const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
+                const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
+
+                // TODO: add ggml_metal_kargs struct
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
+                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
+                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
+                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
+                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
+                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
+                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
+                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
+                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
+                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
+                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
+                [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+            } break;
+        case GGML_OP_ARGMAX:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                int nth = 32; // SIMD width
+                while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+                    nth *= 2;
+                }
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float)   atIndex:0];
+                [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
+        default:
+            {
+                GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+static enum ggml_status ggml_metal_graph_compute(
+            ggml_backend_t   backend,
+        struct ggml_cgraph * gf) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    // number of nodes encoded by the main thread (empirically determined)
+    const int n_main = 128;
+
+    // number of threads in addition to the main thread
+    const int n_cb = ctx->n_cb;
+
+    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
+    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
+    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
+    // each thread creates its own command buffer and enqueues the ops in parallel
+    //
+    // tests on M1 Pro and M2 Ultra using LLaMA models show that the optimal values for n_cb are 1 or 2
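+    //
+    // illustrative example (editor's note, hypothetical numbers): for a graph with 640 nodes and
+    // n_cb = 2, the calling thread encodes nodes [0, 128) into command_buffers[2], while the two
+    // encode_async() workers encode nodes [128, 384) and [384, 640) into command_buffers[0] and
+    // command_buffers[1] respectively (n_nodes_per_cb = 256)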
+
+    @autoreleasepool {
+        ctx->gf = gf;
+
+        ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
+        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
+
+        ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
+
+        const bool should_capture = ctx->capture_next_compute;
+        if (should_capture) {
+            ctx->capture_next_compute = false;
+
+            if (!ctx->capture_started) {
+                // create capture scope
+                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
+
+                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+                descriptor.captureObject = ctx->capture_scope;
+                descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
+                descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
+
+                NSError * error = nil;
+                if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+                } else {
+                    [ctx->capture_scope beginScope];
+                    ctx->capture_started = true;
+                }
+            }
+        }
+
+        // the main thread commits the first few commands immediately
+        // command_buffer[n_cb]
+        {
+            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->command_buffers[n_cb] = command_buffer;
+
+            [command_buffer enqueue];
+            ctx->encode_async(n_cb);
+        }
+
+        // prepare the rest of the command buffers asynchronously
+        // command_buffer[0.. n_cb)
+        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->command_buffers[cb_idx] = command_buffer;
+
+            // always enqueue the first two command buffers
+            // enqueue all of the command buffers if we don't need to abort
+            if (cb_idx < 2 || ctx->abort_callback == NULL) {
+                [command_buffer enqueue];
+            }
+        }
+
+        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
+
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
+            [command_buffer waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [command_buffer status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+        }
+
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
+            [command_buffer waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [command_buffer status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+            if (!next_buffer) {
+                continue;
+            }
+
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            if (next_queued) {
+                continue;
+            }
+
+            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+                GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+                return GGML_STATUS_ABORTED;
+            }
+
+            [next_buffer commit];
+        }
+
+        if (!should_capture && ctx->capture_started) {
+            [ctx->capture_scope endScope];
+            [[MTLCaptureManager sharedCaptureManager] stopCapture];
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend interface
+
+static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    for (int i = 0; i < ctx->n_buffers; i++) {
+        [ctx->buffers[i].metal release];
+    }
+    ggml_backend_metal_device_rel(buffer->buft->device->context);
+
+    if (ctx->owned) {
+#if TARGET_OS_OSX
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
+#else
+        free(ctx->all_data);
+#endif
+    }
+
+    free(ctx);
+}
+
+static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    return ctx->all_data;
+}
+
+static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    memcpy((char *)tensor->data + offset, data, size);
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    UNUSED(buffer);
+}
+
+static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        memcpy(dst->data, src->data, ggml_nbytes(src));
+        return true;
+    }
+    return false;
+
+    UNUSED(buffer);
+}
+
+static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
+
+    memset(ctx->all_data, value, ctx->all_size);
+}
+
+static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
+    /* .free_buffer     = */ ggml_backend_metal_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_metal_buffer_get_base,
+    /* .init_tensor     = */ NULL,
+    /* .memset_tensor   = */ ggml_backend_metal_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_metal_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_metal_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_metal_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_metal_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// default buffer type
+
+static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "Metal";
+
+    UNUSED(buft);
+}
+
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
+#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
+                device.currentAllocatedSize / 1024.0 / 1024.0,
+                device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+
+        if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
+            GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
+        }
+    } else {
+        GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+                __func__,
+                size_aligned / 1024.0 / 1024.0,
+                device.currentAllocatedSize / 1024.0 / 1024.0);
+    }
+#endif
+#endif
+    UNUSED(device);
+    UNUSED(size_aligned);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
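+    // editor's note (illustrative): with the 16 KiB pages used on Apple silicon, a request of
+    // 1000000 bytes is rounded up to 1015808 bytes (62 pages) so that newBufferWithBytesNoCopy
+    // receives a page-multiple length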
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+
+    ctx->all_data = ggml_metal_host_malloc(size_aligned);
+    ctx->all_size = size_aligned;
+    ctx->owned = true;
+    ctx->n_buffers = 1;
+
+    if (ctx->all_data != NULL) {
+        ctx->buffers[0].data  = ctx->all_data;
+        ctx->buffers[0].size  = size;
+        ctx->buffers[0].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+                                            length:size_aligned
+                                            options:MTLResourceStorageModeShared
+                                            deallocator:nil];
+        }
+    }
+
+    if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
+        GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+        free(ctx);
+        ggml_backend_metal_device_rel(buft->device->context);
+        return NULL;
+    }
+
+    //ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+    return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
+}
+
+static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 32;
+    UNUSED(buft);
+}
+
+static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+    const size_t max_size = device.maxBufferLength;
+    ggml_backend_metal_device_rel(buft->device->context);
+
+    return max_size;
+
+    UNUSED(buft);
+}
+
+static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return true;
+
+    UNUSED(buft);
+}
+
+ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
+        /* .iface = */ {
+            /* .get_name         = */ ggml_backend_metal_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
+        },
+        /* .device  = */ &g_ggml_backend_metal_device,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_buffer_type_metal;
+}
+
+static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "Metal_Mapped";
+
+    UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = {
+        /* .iface = */ {
+            /* .get_name         = */ ggml_backend_metal_buffer_from_ptr_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_metal_buffer_type_get_alignment,
+            /* .get_max_size     = */ ggml_backend_metal_buffer_type_get_max_size,
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
+        },
+        /* .device  = */ &g_ggml_backend_metal_device,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_buffer_from_ptr_type_metal;
+}
+
+// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
+ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
+    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+    ctx->all_data = data;
+    ctx->all_size = size;
+    ctx->owned = false;
+    ctx->n_buffers = 0;
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    // page-align the data ptr
+    {
+        const uintptr_t offs = (uintptr_t) data % size_page;
+        data  = (void *) ((char *) data - offs);
+        size += offs;
+    }
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
+
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].data  = data;
+        ctx->buffers[ctx->n_buffers].size  = size;
+        ctx->buffers[ctx->n_buffers].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+        }
+
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+        ++ctx->n_buffers;
+    } else {
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = device.maxBufferLength - size_ovlp;
+        const size_t size_view = device.maxBufferLength;
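+        // illustrative example (editor's note, hypothetical numbers): if the device allows 8 GiB
+        // per buffer and the largest tensor is 1 GiB, then size_ovlp is ~1 GiB, size_step is
+        // ~7 GiB and consecutive 8 GiB views start 7 GiB apart, so any single tensor is fully
+        // contained in at least one view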
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) data + i);
+            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].metal = nil;
+
+            if (size_step_aligned > 0) {
+                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+            }
+
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+            if (i + size_step < size) {
+                GGML_LOG_INFO("\n");
+            }
+
+            ++ctx->n_buffers;
+        }
+    }
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+// backend
+
+static const char * ggml_backend_metal_name(ggml_backend_t backend) {
+    return "Metal";
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_metal_free(ggml_backend_t backend) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_metal_device_rel(ctx_dev);
+    ggml_metal_free(ctx);
+
+    free(backend);
+}
+
+static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    return ggml_metal_graph_compute(backend, cgraph);
+}
+
+static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+
+    if (ctx->n_cb != n_cb) {
+        ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
+
+        if (ctx->n_cb > 2) {
+            GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
+        }
+    }
+
+    if (ctx->encode_async) {
+        Block_release(ctx->encode_async);
+    }
+
+    ctx->encode_async = Block_copy(^(size_t iter) {
+        const int cb_idx = iter;
+        const int n_cb_l = ctx->n_cb;
+
+        const int n_nodes_0 = ctx->n_nodes_0;
+        const int n_nodes_1 = ctx->n_nodes_1;
+
+        const int n_nodes_per_cb = ctx->n_nodes_per_cb;
+
+        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+
+        int node_start = 0;
+        int node_end   = n_nodes_0;
+
+        if (cb_idx < n_cb_l) {
+            node_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
+            node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
+        }
+
+        const bool should_capture = ctx->capture_next_compute;
+
+        for (int idx = node_start; idx < node_end; ++idx) {
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
+            }
+
+            ggml_metal_encode_node(backend, idx, encoder);
+
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
+        }
+
+        [encoder endEncoding];
+
+        if (cb_idx < 2 || ctx->abort_callback == NULL) {
+            [command_buffer commit];
+        }
+    });
+}
+
+static struct ggml_backend_i ggml_backend_metal_i = {
+    /* .get_name                = */ ggml_backend_metal_name,
+    /* .free                    = */ ggml_backend_metal_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_metal_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_metal_guid(void) {
+    static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
+    return &guid;
+}
+
+// TODO: remove in the future
+ggml_backend_t ggml_backend_metal_init(void) {
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+    *backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
+        /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
+        /* .context   = */ ctx,
+    };
+
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;
+}
+
+bool ggml_backend_is_metal(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
+}
+
+void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = user_data;
+}
+
+bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+}
+
+void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    ctx->capture_next_compute = true;
+}
+
+// backend device
+
+static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
+    return "Metal";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
+    // acq/rel just to populate ctx->name in case it hasn't been done yet
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    ggml_backend_metal_device_acq(ctx_dev);
+    ggml_backend_metal_device_rel(ctx_dev);
+
+    return ctx_dev->name;
+}
+
+static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+        *total = device.recommendedMaxWorkingSetSize;
+        *free  = *total - device.currentAllocatedSize;
+
+        ggml_backend_metal_device_rel(ctx_dev);
+    } else {
+        *free = 1;
+        *total = 1;
+    }
+}
+
+static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_metal_device_get_name(dev);
+    props->description = ggml_backend_metal_device_get_description(dev);
+    props->type        = ggml_backend_metal_device_get_type(dev);
+    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = (struct ggml_backend_dev_caps) {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+    *backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
+        /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
+        /* .context   = */ ctx,
+    };
+
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;
+
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_metal_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+    ctx->all_data = ptr;
+    ctx->all_size = size;
+    ctx->owned = false;
+    ctx->n_buffers = 0;
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    // page-align the data ptr
+    {
+        const uintptr_t offs = (uintptr_t) ptr % size_page;
+        ptr  = (void *) ((char *) ptr - offs);
+        size += offs;
+    }
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].data  = ptr;
+        ctx->buffers[ctx->n_buffers].size  = size;
+        ctx->buffers[ctx->n_buffers].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+        }
+
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+        ++ctx->n_buffers;
+    } else {
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = device.maxBufferLength - size_ovlp;
+        const size_t size_view = device.maxBufferLength;
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) ptr + i);
+            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].metal = nil;
+
+            if (size_step_aligned > 0) {
+                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+            }
+
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+            if (i + size_step < size) {
+                GGML_LOG_INFO("\n");
+            }
+
+            ++ctx->n_buffers;
+        }
+    }
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    return ggml_metal_supports_op(ctx_dev, op);
+}
+
+static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
+            buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
+
+    UNUSED(dev);
+}
+
+static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    return false;
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(op);
+}
+
+static struct ggml_backend_device_i ggml_backend_metal_device_i = {
+    /* .get_name             = */ ggml_backend_metal_device_get_name,
+    /* .get_description      = */ ggml_backend_metal_device_get_description,
+    /* .get_memory           = */ ggml_backend_metal_device_get_memory,
+    /* .get_type             = */ ggml_backend_metal_device_get_type,
+    /* .get_props            = */ ggml_backend_metal_device_get_props,
+    /* .init_backend         = */ ggml_backend_metal_device_init,
+    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_metal_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_metal_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend registry
+
+static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
+    return "Metal";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    return &g_ggml_backend_metal_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static struct ggml_backend_feature g_ggml_backend_metal_features[] = {
+#if defined(GGML_METAL_EMBED_LIBRARY)
+    { "EMBED_LIBRARY", "1" },
+#endif
+#if defined(GGML_METAL_USE_BF16)
+    { "BF16", "1" },
+#endif
+    { nil, nil },
+};
+
+static struct ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) {
+    return g_ggml_backend_metal_features;
+
+    GGML_UNUSED(reg);
+}
+
+static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_get_features") == 0) {
+        return (void *)ggml_backend_metal_get_features;
+    }
+
+    return NULL;
+
+    GGML_UNUSED(reg);
+}
+static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
+    /* .get_name         = */ ggml_backend_metal_reg_get_name,
+    /* .device_count     = */ ggml_backend_metal_reg_device_count,
+    /* .device_get       = */ ggml_backend_metal_reg_device_get,
+    /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
+    // TODO: make this thread-safe somehow?
+    {
+        g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
+            /* .api_version = */ GGML_BACKEND_API_VERSION,
+            /* .iface       = */ ggml_backend_metal_reg_i,
+            /* .context     = */ NULL,
+        };
+
+        g_ggml_backend_metal_device = (struct ggml_backend_device) {
+            /* .iface   = */ ggml_backend_metal_device_i,
+            /* .reg     = */ &g_ggml_backend_metal_reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+    }
+
+    return &g_ggml_backend_metal_reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)
diff --git a/llama/ggml-quants.c b/llama/ggml-quants.c
new file mode 100644
index 000000000..6f824d420
--- /dev/null
+++ b/llama/ggml-quants.c
@@ -0,0 +1,5264 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define GGML_COMMON_IMPL_C
+#include "ggml-common.h"
+
+#include "ggml-quants.h"
+#include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu.h"
+
+#include <math.h>
+#include <string.h>
+#include <assert.h>
+#include <float.h>
+#include <stdlib.h> // for qsort
+#include <stdio.h>  // for GGML_ASSERT
+
+#define GROUP_MAX_EPS 1e-15f
+#define GROUP_MAX_EPS_IQ3_XXS 1e-8f
+#define GROUP_MAX_EPS_IQ2_S 1e-8f
+#define GROUP_MAX_EPS_IQ1_M 1e-7f
+#define GROUP_MAX_EPS_IQ1_S 1e-12f
+
+#if defined(_MSC_VER)
+// disable "possible loss of data" to avoid warnings for hundreds of casts
+// we should just be careful :)
+#pragma warning(disable: 4244 4267)
+#endif
+
+#define UNUSED GGML_UNUSED
+
+// reference implementation for deterministic creation of model files
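+// Q4_0 layout: each block of QK4_0 = 32 floats stores one fp16 scale d and 16 bytes of
+// packed 4-bit quants. d is chosen as max/-8 (max by absolute value), each value is
+// mapped to [0, 15] with an offset of 8, and element j shares a byte with element
+// j + 16 (low nibble / high nibble).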
+void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
+        }
+
+        const float d  = max / -8;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
+void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
+    const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 4) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+            y[i].qs[j]  = xi0;
+            y[i].qs[j] |= xi1 << 4;
+        }
+    }
+}
+
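+// Q5_0 layout: like Q4_0 but with 5-bit quants. The low 4 bits of each value are packed
+// into the qs nibbles, and the 5th bits are collected into a 32-bit mask qh (bit j for
+// the first half of the block, bit j + 16 for the second half).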
+void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+        float max  = 0.0f;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+                max  = v;
+            }
+        }
+
+        const float d  = max / -16;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = x[i*qk + 0    + j]*id;
+            const float x1 = x[i*qk + qk/2 + j]*id;
+
+            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(qh));
+    }
+}
+
+void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
+    const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (v < min) min = v;
+            if (v > max) max = v;
+        }
+
+        const float d  = (max - min) / ((1 << 5) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const float x0 = (x[i*qk + 0    + j] - min)*id;
+            const float x1 = (x[i*qk + qk/2 + j] - min)*id;
+
+            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+            y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2);
+        }
+
+        memcpy(&y[i].qh, &qh, sizeof(y[i].qh));
+    }
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_0; j++) {
+            const float v = x[i*QK8_0 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < QK8_0; ++j) {
+            const float x0 = x[i*QK8_0 + j]*id;
+
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
+
+// reference implementation for deterministic creation of model files
+void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
+    assert(QK8_1 == 32);
+    assert(k % QK8_1 == 0);
+    const int nb = k / QK8_1;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK8_1; j++) {
+            const float v = x[i*QK8_1 + j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        int sum = 0;
+
+        for (int j = 0; j < QK8_1/2; ++j) {
+            const float v0 = x[i*QK8_1           + j]*id;
+            const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id;
+
+            y[i].qs[          j] = roundf(v0);
+            y[i].qs[QK8_1/2 + j] = roundf(v1);
+
+            sum += y[i].qs[          j];
+            sum += y[i].qs[QK8_1/2 + j];
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(sum*d);
+    }
+}
+
+void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
+    static const int qk = QK4_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F) - 8;
+            const int x1 = (x[i].qs[j] >>   4) - 8;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) {
+    static const int qk = QK4_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int x0 = (x[i].qs[j] & 0x0F);
+            const int x1 = (x[i].qs[j] >>   4);
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int64_t k) {
+    static const int qk = QK5_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
+            const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
+void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int64_t k) {
+    static const int qk = QK5_1;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
+
+        uint32_t qh;
+        memcpy(&qh, x[i].qh, sizeof(qh));
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
+
+            const int x0 = (x[i].qs[j] & 0x0F) | xh_0;
+            const int x1 = (x[i].qs[j] >>   4) | xh_1;
+
+            y[i*qk + j + 0   ] = x0*d + m;
+            y[i*qk + j + qk/2] = x1*d + m;
+        }
+    }
+}
+
+void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int64_t k) {
+    static const int qk = QK8_0;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int j = 0; j < qk; ++j) {
+            y[i*qk + j] = x[i].qs[j]*d;
+        }
+    }
+}
+
+//
+// 2-6 bit quantization in super-blocks
+//
+
+//
+// ===================== Helper functions
+//
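+// nearest_int() rounds by adding 1.5 * 2^23 (12582912.0f) so that the fractional bits
+// are pushed out of the float mantissa; the rounded integer is then read back from the
+// low 23 mantissa bits, biased by 0x00400000. Valid only for |fval| <= 2^22.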
+static inline int nearest_int(float fval) {
+    assert(fabsf(fval) <= 4194303.f);
+    float val = fval + 12582912.f;
+    int i; memcpy(&i, &val, sizeof(int));
+    return (i & 0x007fffff) - 0x00400000;
+}
+
+static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type,
+        const float * restrict qw) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < GROUP_MAX_EPS) { // all zero
+        for (int i = 0; i < n; ++i) {
+            L[i] = 0;
+        }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (rmse_type == 0) {
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+        }
+        return 1/iscale;
+    }
+    bool return_early = false;
+    if (rmse_type < 0) {
+        rmse_type = -rmse_type;
+        return_early = true;
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 0; i < n; ++i) {
+#else
+    for (int i = 0; i < n; ++i) {
+#endif
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+        float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    float scale = suml2 ? sumlx/suml2 : 0.0f;
+    if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
+    float best = scale * sumlx;
+    for (int is = -9; is <= 9; ++is) {
+        if (is == 0) {
+            continue;
+        }
+        iscale = -(nmax + 0.1f*is) / max;
+        sumlx = suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        if (suml2 > 0 && sumlx*sumlx > best*suml2) {
+            for (int i = 0; i < n; ++i) {
+                int l = nearest_int(iscale * x[i]);
+                L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
+            }
+            scale = sumlx/suml2; best = scale*sumlx;
+        }
+    }
+    return scale;
+}
+
+static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
+    float max = 0;
+    float amax = 0;
+    for (int i = 0; i < n; ++i) {
+        float ax = fabsf(x[i]);
+        if (ax > amax) { amax = ax; max = x[i]; }
+    }
+    if (amax < GROUP_MAX_EPS) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = -nmax / max;
+    if (do_rmse) {
+        float sumlx = 0;
+        float suml2 = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale * x[i]);
+            l = MAX(-nmax, MIN(nmax-1, l));
+            L[i] = l;
+            float w = x[i]*x[i];
+            sumlx += w*x[i]*l;
+            suml2 += w*l*l;
+        }
+        for (int itry = 0; itry < 5; ++itry) {
+            int n_changed = 0;
+            for (int i = 0; i < n; ++i) {
+                float w = x[i]*x[i];
+                float slx = sumlx - w*x[i]*L[i];
+                if (slx > 0) {
+                    float sl2 = suml2 - w*L[i]*L[i];
+                    int new_l = nearest_int(x[i] * sl2 / slx);
+                    new_l = MAX(-nmax, MIN(nmax-1, new_l));
+                    if (new_l != L[i]) {
+                        slx += w*x[i]*new_l;
+                        sl2 += w*new_l*new_l;
+                        if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
+                            L[i] = new_l; sumlx = slx; suml2 = sl2;
+                            ++n_changed;
+                        }
+                    }
+                }
+            }
+            if (!n_changed) {
+                break;
+            }
+        }
+        for (int i = 0; i < n; ++i) {
+            L[i] += nmax;
+        }
+        return sumlx / suml2;
+    }
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MAX(-nmax, MIN(nmax-1, l));
+        L[i] = l + nmax;
+    }
+    return 1/iscale;
+}
+
+static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
+        int ntry, float alpha) {
+    float min = x[0];
+    float max = x[0];
+    for (int i = 1; i < n; ++i) {
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+    }
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = 0;
+        return 0.f;
+    }
+    if (min > 0) min = 0;
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    for (int itry = 0; itry < ntry; ++itry) {
+        float sumlx = 0; int suml2 = 0;
+        bool did_change = false;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            if (l != L[i]) {
+                L[i] = l;
+                did_change = true;
+            }
+            sumlx += (x[i] - min)*l;
+            suml2 += l*l;
+        }
+        scale = sumlx/suml2;
+        float sum = 0;
+        for (int i = 0; i < n; ++i) {
+            sum += x[i] - scale*L[i];
+        }
+        min = alpha*min + (1 - alpha)*sum/n;
+        if (min > 0) min = 0;
+        iscale = 1/scale;
+        if (!did_change) break;
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights[0];
+    float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
+    for (int i = 1; i < n; ++i) {
+#endif
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) min = 0;
+    if (max == min) {
+        for (int i = 0; i < n; ++i) L[i] = 0;
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff * diff;
+        float w = weights[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights[i];
+            sum_l += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff * diff;
+                float w = weights[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
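+// Q4_K/Q5_K super-blocks store eight 6-bit (scale, min) pairs packed into 12 bytes:
+// pairs 0-3 occupy the low 6 bits of bytes 0-7, while pairs 4-7 use the nibbles of
+// bytes 8-11 plus the top 2 bits of bytes 0-7. This helper unpacks pair j.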
+static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
+    if (j < 4) {
+        *d = q[j] & 63; *m = q[j + 4] & 63;
+    } else {
+        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        *m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+    }
+}
+
+//========================- 2-bit (de)-quantization
+
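+// Q2_K layout: a super-block of QK_K values is split into 16-element groups. Each group
+// gets a 4-bit scale and a 4-bit min (packed together in y[i].scales), applied on top of
+// the fp16 super-block factors d and dmin; the 2-bit quants are packed four per byte in qs.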
+void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float   weights[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+
+    const float q4scale = 15.f;
+
+    for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]);
+            scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        if (max_scale > 0) {
+            float iscale = q4scale/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*scales[j]);
+                y[i].scales[j] = l;
+            }
+            y[i].d = GGML_FP32_TO_FP16(max_scale/q4scale);
+        } else {
+            for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+        if (max_min > 0) {
+            float iscale = q4scale/max_min;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int l = nearest_int(iscale*mins[j]);
+                y[i].scales[j] |= (l << 4);
+            }
+            y[i].dmin = GGML_FP32_TO_FP16(max_min/q4scale);
+        } else {
+            y[i].dmin = GGML_FP32_TO_FP16(0.f);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float d = GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF);
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4);
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int((x[16*j + ii] + dm)/d);
+                l = MAX(0, MIN(3, l));
+                L[16*j + ii] = l;
+            }
+        }
+
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+    }
+}
+
+static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
+        uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
+        float rmin, float rdelta, int nstep, bool use_mad) {
+    float min = x[0];
+    float max = x[0];
+    float sum_w = weights ? weights[0] : x[0]*x[0];
+    float sum_x = sum_w * x[0];
+#ifdef HAVE_BUGGY_APPLE_LINKER
+    // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
+    for (volatile int i = 1; i < n; ++i) {
+#else
+    for (int i = 1; i < n; ++i) {
+#endif
+        if (x[i] < min) min = x[i];
+        if (x[i] > max) max = x[i];
+        float w = weights ? weights[i] : x[i]*x[i];
+        sum_w += w;
+        sum_x += w * x[i];
+    }
+    if (min > 0) {
+        min = 0;
+    }
+    if (max <= min) {
+        memset(L, 0, n);
+        *the_min = -min;
+        return 0.f;
+    }
+    float iscale = nmax/(max - min);
+    float scale = 1/iscale;
+    float best_mad = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale*(x[i] - min));
+        L[i] = MAX(0, MIN(nmax, l));
+        float diff = scale * L[i] + min - x[i];
+        diff = use_mad ? fabsf(diff) : diff*diff;
+        float w = weights ? weights[i] : x[i]*x[i];
+        best_mad += w * diff;
+    }
+    if (nstep < 1) {
+        *the_min = -min;
+        return scale;
+    }
+    for (int is = 0; is <= nstep; ++is) {
+        iscale = (rmin + rdelta*is + nmax)/(max - min);
+        float sum_l = 0, sum_l2 = 0, sum_xl = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale*(x[i] - min));
+            l = MAX(0, MIN(nmax, l));
+            Laux[i] = l;
+            float w = weights ? weights[i] : x[i]*x[i];
+            sum_l  += w*l;
+            sum_l2 += w*l*l;
+            sum_xl += w*l*x[i];
+        }
+        float D = sum_w * sum_l2 - sum_l * sum_l;
+        if (D > 0) {
+            float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
+            float this_min   = (sum_l2 * sum_x - sum_l * sum_xl)/D;
+            if (this_min > 0) {
+                this_min = 0;
+                this_scale = sum_xl / sum_l2;
+            }
+            float mad = 0;
+            for (int i = 0; i < n; ++i) {
+                float diff = this_scale * Laux[i] + this_min - x[i];
+                diff = use_mad ? fabsf(diff) : diff*diff;
+                float w = weights ? weights[i] : x[i]*x[i];
+                mad += w * diff;
+            }
+            if (mad < best_mad) {
+                for (int i = 0; i < n; ++i) {
+                    L[i] = Laux[i];
+                }
+                best_mad = mad;
+                scale = this_scale;
+                min = this_min;
+            }
+        }
+    }
+    *the_min = -min;
+    return scale;
+}
+
+static float make_qp_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, const float * quant_weights) {
+    float max = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, x[i]);
+    }
+    if (!max) { // all zero
+        for (int i = 0; i < n; ++i) { L[i] = 0; }
+        return 0.f;
+    }
+    float iscale = nmax / max;
+    for (int i = 0; i < n; ++i) {
+        L[i] = nearest_int(iscale * x[i]);
+    }
+    float scale = 1/iscale;
+    float best_mse = 0;
+    for (int i = 0; i < n; ++i) {
+        float diff = x[i] - scale*L[i];
+        float w = quant_weights[i];
+        best_mse += w*diff*diff;
+    }
+    for (int is = -4; is <= 4; ++is) {
+        if (is == 0) continue;
+        float iscale_is = (0.1f*is + nmax)/max;
+        float scale_is = 1/iscale_is;
+        float mse = 0;
+        for (int i = 0; i < n; ++i) {
+            int l = nearest_int(iscale_is*x[i]);
+            l = MIN(nmax, l);
+            float diff = x[i] - scale_is*l;
+            float w = quant_weights[i];
+            mse += w*diff*diff;
+        }
+        if (mse < best_mse) {
+            best_mse = mse;
+            iscale = iscale_is;
+        }
+    }
+    float sumlx = 0;
+    float suml2 = 0;
+    for (int i = 0; i < n; ++i) {
+        int l = nearest_int(iscale * x[i]);
+        l = MIN(nmax, l);
+        L[i] = l;
+        float w = quant_weights[i];
+        sumlx += w*x[i]*l;
+        suml2 += w*l*l;
+    }
+    for (int itry = 0; itry < 5; ++itry) {
+        int n_changed = 0;
+        for (int i = 0; i < n; ++i) {
+            float w = quant_weights[i];
+            float slx = sumlx - w*x[i]*L[i];
+            float sl2 = suml2 - w*L[i]*L[i];
+            if (slx > 0 && sl2 > 0) {
+                int new_l = nearest_int(x[i] * sl2 / slx);
+                new_l = MIN(nmax, new_l);
+                if (new_l != L[i]) {
+                    slx += w*x[i]*new_l;
+                    sl2 += w*new_l*new_l;
+                    if (slx*slx*suml2 > sumlx*sumlx*sl2) {
+                        L[i] = new_l; sumlx = slx; suml2 = sl2;
+                        ++n_changed;
+                    }
+                }
+            }
+        }
+        if (!n_changed) {
+            break;
+        }
+    }
+    return sumlx/suml2;
+}
+
+static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restrict y, int k, const float * restrict quant_weights) {
+    GGML_ASSERT(quant_weights);
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+    const bool requantize = true;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[16];
+    float mins[QK_K/16];
+    float scales[QK_K/16];
+    float sw[QK_K/16];
+    float weight[16];
+    uint8_t Ls[QK_K/16], Lm[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+        memset(sw, 0, QK_K/16*sizeof(float));
+        float sumx2 = 0;
+        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+        float sigma2 = sumx2/QK_K;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const float * restrict qw = quant_weights + QK_K * i + 16*j;
+            for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
+            for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
+            scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+        }
+
+        float dm, mm;
+        dm  = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
+        mm  = make_qp_quants(QK_K/16, 15, mins,   Lm, sw);
+
+        y[i].d    = GGML_FP32_TO_FP16(dm);
+        y[i].dmin = GGML_FP32_TO_FP16(mm);
+        dm        = GGML_FP16_TO_FP32(y[i].d);
+        mm        = GGML_FP16_TO_FP32(y[i].dmin);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            y[i].scales[j] = Ls[j] | (Lm[j] << 4);
+        }
+
+        if (requantize) {
+            for (int j = 0; j < QK_K/16; ++j) {
+                const float d = dm * (y[i].scales[j] & 0xF);
+                if (!d) continue;
+                const float m = mm * (y[i].scales[j] >> 4);
+                for (int ii = 0; ii < 16; ++ii) {
+                    int l = nearest_int((x[16*j + ii] + m)/d);
+                    l = MAX(0, MIN(3, l));
+                    L[16*j + ii] = l;
+                }
+            }
+        }
+
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+
+        x += QK_K;
+    }
+}
+
+size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int64_t row = 0; row < nrow; ++row) {
+            quantize_row_q2_K_impl(src, (block_q2_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
+//========================= 3-bit (de)-quantization
+
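+// Q3_K layout: 16-element groups with quants in [-4, 3], stored as 2 low bits per value
+// in qs plus one high bit per value in the hmask bitmask. Group scales are 6-bit values
+// packed into the 12-byte scales array and multiplied by the fp16 super-block d.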
+void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K / 16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K/16; ++j) {
+            scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
+            float scale = fabsf(scales[j]);
+            if (scale > amax) {
+                amax = scale; max_scale = scales[j];
+            }
+        }
+
+        memset(y[i].scales, 0, 12);
+        if (max_scale) {
+            float iscale = -32.f/max_scale;
+            for (int j = 0; j < QK_K/16; ++j) {
+                int8_t l = nearest_int(iscale*scales[j]);
+                l = MAX(-32, MIN(31, l)) + 32;
+                if (j < 8) {
+                    y[i].scales[j] = l & 0xF;
+                } else {
+                    y[i].scales[j-8] |= ((l & 0xF) << 4);
+                }
+                l >>= 4;
+                y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+            }
+            y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        } else {
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+        }
+
+        int8_t sc;
+        for (int j = 0; j < QK_K/16; ++j) {
+            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+            float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+
+        memset(y[i].hmask, 0, QK_K/8);
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+        int m = 0;
+        uint8_t hm = 1;
+        for (int j = 0; j < QK_K; ++j) {
+            if (L[j] > 3) {
+                y[i].hmask[m] |= hm;
+                L[j] -= 4;
+            }
+            if (++m == QK_K/8) {
+                m = 0; hm <<= 1;
+            }
+        }
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    const uint32_t kmask1 = 0x03030303;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+
+    uint32_t aux[4];
+    const int8_t * scales = (const int8_t*)aux;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        uint8_t m = 1;
+
+        memcpy(aux, x[i].scales, 12);
+        uint32_t tmp = aux[2];
+        aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+
+        int is = 0;
+        float dl;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
+                }
+
+                dl = d_all * (scales[is++] - 32);
+                for (int l = 0; l < 16; ++l) {
+                    *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
+                }
+
+                shift += 2;
+                m <<= 1;
+            }
+            q += 32;
+        }
+
+    }
+}
+
+static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
+    assert(n_per_row % QK_K == 0);
+    const int nb = n_per_row / QK_K;
+
+    int8_t L[QK_K];
+    float scales[QK_K / 16];
+    float weight[16];
+    float sw[QK_K / 16];
+    int8_t Ls[QK_K / 16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float sumx2 = 0;
+        for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K * i + 16*j;
+                for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j+l]*x[16*j+l]);
+            } else {
+                for (int l = 0; l < 16; ++l) weight[l] = x[16*j+l]*x[16*j+l];
+            }
+            float sumw = 0;
+            for (int l = 0; l < 16; ++l) sumw += weight[l];
+            sw[j] = sumw;
+
+            scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight);
+
+        }
+
+        memset(y[i].scales, 0, 12);
+
+        float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw);
+        for (int j = 0; j < QK_K/16; ++j) {
+            int l = Ls[j];
+            if (j < 8) {
+                y[i].scales[j] = l & 0xF;
+            } else {
+                y[i].scales[j-8] |= ((l & 0xF) << 4);
+            }
+            l >>= 4;
+            y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
+        }
+        y[i].d = GGML_FP32_TO_FP16(d_block);
+
+        int8_t sc;
+        for (int j = 0; j < QK_K/16; ++j) {
+            sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
+            sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
+            float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-4, MIN(3, l));
+                L[16*j + ii] = l + 4;
+            }
+        }
+
+        memset(y[i].hmask, 0, QK_K/8);
+        // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc.
+        int m = 0;
+        uint8_t hm = 1;
+        for (int j = 0; j < QK_K; ++j) {
+            if (L[j] > 3) {
+                y[i].hmask[m] |= hm;
+                L[j] -= 4;
+            }
+            if (++m == QK_K/8) {
+                m = 0; hm <<= 1;
+            }
+        }
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
+            }
+        }
+
+        x += QK_K;
+    }
+}
+
+size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int64_t row = 0; row < nrow; ++row) {
+            quantize_row_q3_K_impl(src, (block_q3_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
+// ====================== 4-bit (de)-quantization
+
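+// Q4_K layout: 32-element groups, each with a 6-bit scale and 6-bit min unpacked by
+// get_scale_min_k4() and scaled by the fp16 super-block d and dmin; within each 64-value
+// chunk, elements l and l + 32 share a byte (low nibble / high nibble).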
+void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    float   weights[32];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
+        }
+
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * q = x[i].qs;
+
+        const float d   = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l]  >> 4) - m2;
+            q += 32; is += 2;
+        }
+    }
+}
+
+static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+    assert(n_per_row % QK_K == 0);
+    const int64_t nb = n_per_row / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    uint8_t Ls[QK_K/32];
+    uint8_t Lm[QK_K/32];
+    float   weights[32];
+    float   sw[QK_K/32];
+    float   mins[QK_K/32];
+    float   scales[QK_K/32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float sum_x2 = 0;
+        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+        float sigma2 = 2*sum_x2/QK_K;
+        float av_x = sqrtf(sigma2);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*i + 32*j;
+                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+            } else {
+                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            }
+            float sumw = 0;
+            for (int l = 0; l < 32; ++l) sumw += weights[l];
+            sw[j] = sumw;
+            scales[j] = make_qkx3_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+        }
+
+        float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
+        float m_block = make_qp_quants(QK_K/32, 63, mins,   Lm, sw);
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = Ls[j];
+            uint8_t lm = Lm[j];
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(d_block);
+        y[i].dmin = GGML_FP32_TO_FP16(m_block);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(15, l));
+                L[32*j + ii] = l;
+            }
+        }
+        uint8_t * q = y[i].qs;
+        for (int j = 0; j < QK_K; j += 64) {
+            for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4);
+            q += 32;
+        }
+
+        x += QK_K;
+
+    }
+}
+
+size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int64_t row = 0; row < nrow; ++row) {
+            quantize_row_q4_K_impl(src, (block_q4_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
+// ====================== 5-bit (de)-quantization
+
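+// Q5_K layout: same 6-bit scale/min scheme as Q4_K, with the low 4 bits of each quant
+// packed as nibbles in qs and the 5th bits collected into the 32-byte qh array (each
+// 64-value chunk contributes two bits per qh byte).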
+void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    uint8_t L[QK_K];
+    float mins[QK_K/32];
+    float scales[QK_K/32];
+    float weights[32];
+    uint8_t Laux[32];
+
+    for (int i = 0; i < nb; i++) {
+        float max_scale = 0; // as we are deducting the min, scales are always positive
+        float max_min = 0;
+        for (int j = 0; j < QK_K/32; ++j) {
+            //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f);
+            float sum_x2 = 0;
+            for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l];
+            float av_x = sqrtf(sum_x2/32);
+            for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false);
+            float scale = scales[j];
+            if (scale > max_scale) {
+                max_scale = scale;
+            }
+            float min = mins[j];
+            if (min > max_min) {
+                max_min = min;
+            }
+        }
+
+        float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
+        float inv_min   = max_min   > 0 ? 63.f/max_min   : 0.f;
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = nearest_int(inv_scale*scales[j]);
+            uint8_t lm = nearest_int(inv_min*mins[j]);
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(max_scale/63.f);
+        y[i].dmin = GGML_FP32_TO_FP16(max_min/63.f);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(31, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        uint8_t m1 = 1, m2 = 2;
+        for (int n = 0; n < QK_K; n += 64) {
+            for (int j = 0; j < 32; ++j) {
+                int l1 = L[n + j];
+                if (l1 > 15) {
+                    l1 -= 16; qh[j] |= m1;
+                }
+                int l2 = L[n + j + 32];
+                if (l2 > 15) {
+                    l2 -= 16; qh[j] |= m2;
+                }
+                ql[j] = l1 | (l2 << 4);
+            }
+            m1 <<= 2; m2 <<= 2;
+            ql += 32;
+        }
+
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        const uint8_t * ql = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float min = GGML_FP16_TO_FP32(x[i].dmin);
+
+        int is = 0;
+        uint8_t sc, m;
+        uint8_t u1 = 1, u2 = 2;
+        for (int j = 0; j < QK_K; j += 64) {
+            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
+            const float d1 = d * sc; const float m1 = min * m;
+            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
+            const float d2 = d * sc; const float m2 = min * m;
+            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l]  >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
+            ql += 32; is += 2;
+            u1 <<= 2; u2 <<= 2;
+        }
+    }
+}
+
+static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+    assert(n_per_row % QK_K == 0);
+    const int64_t nb = n_per_row / QK_K;
+
+    uint8_t L[QK_K];
+    uint8_t Laux[32];
+    uint8_t Ls[QK_K/32];
+    uint8_t Lm[QK_K/32];
+    float   mins[QK_K/32];
+    float   scales[QK_K/32];
+    float   sw[QK_K/32];
+    float   weights[32];
+
+    for (int i = 0; i < nb; i++) {
+
+        float sum_x2 = 0;
+        for (int l = 0; l < QK_K; ++l) sum_x2 += x[l] * x[l];
+        float sigma2 = 2*sum_x2/QK_K;
+        float av_x = sqrtf(sigma2);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*i + 32*j;
+                for (int l = 0; l < 32; ++l) weights[l] = qw[l] * sqrtf(sigma2 + x[32*j + l]*x[32*j + l]);
+            } else {
+                for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]);
+            }
+            float sumw = 0;
+            for (int l = 0; l < 32; ++l) sumw += weights[l];
+            sw[j] = sumw;
+
+            scales[j] = make_qkx3_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
+        }
+
+        float d_block = make_qp_quants(QK_K/32, 63, scales, Ls, sw);
+        float m_block = make_qp_quants(QK_K/32, 63, mins,   Lm, sw);
+
+        for (int j = 0; j < QK_K/32; ++j) {
+            uint8_t ls = Ls[j];
+            uint8_t lm = Lm[j];
+            ls = MIN(63, ls);
+            lm = MIN(63, lm);
+            if (j < 4) {
+                y[i].scales[j] = ls;
+                y[i].scales[j+4] = lm;
+            } else {
+                y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
+                y[i].scales[j-4] |= ((ls >> 4) << 6);
+                y[i].scales[j-0] |= ((lm >> 4) << 6);
+            }
+        }
+        y[i].d = GGML_FP32_TO_FP16(d_block);
+        y[i].dmin = GGML_FP32_TO_FP16(m_block);
+
+        uint8_t sc, m;
+        for (int j = 0; j < QK_K/32; ++j) {
+            get_scale_min_k4(j, y[i].scales, &sc, &m);
+            const float d = GGML_FP16_TO_FP32(y[i].d) * sc;
+            if (!d) continue;
+            const float dm = GGML_FP16_TO_FP32(y[i].dmin) * m;
+            for (int ii = 0; ii < 32; ++ii) {
+                int l = nearest_int((x[32*j + ii] + dm)/d);
+                l = MAX(0, MIN(31, l));
+                L[32*j + ii] = l;
+            }
+        }
+
+        uint8_t * restrict qh = y[i].qh;
+        uint8_t * restrict ql = y[i].qs;
+        memset(qh, 0, QK_K/8);
+
+        uint8_t m1 = 1, m2 = 2;
+        for (int n = 0; n < QK_K; n += 64) {
+            for (int j = 0; j < 32; ++j) {
+                int l1 = L[n + j];
+                if (l1 > 15) {
+                    l1 -= 16; qh[j] |= m1;
+                }
+                int l2 = L[n + j + 32];
+                if (l2 > 15) {
+                    l2 -= 16; qh[j] |= m2;
+                }
+                ql[j] = l1 | (l2 << 4);
+            }
+            m1 <<= 2; m2 <<= 2;
+            ql += 32;
+        }
+
+        x += QK_K;
+
+    }
+}
+
+size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int64_t row = 0; row < nrow; ++row) {
+            quantize_row_q5_K_impl(src, (block_q5_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
+// ====================== 6-bit (de)-quantization
+
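+// Q6_K layout: 16-element groups with signed 8-bit scales and one fp16 super-block d.
+// Each 6-bit quant is split into 4 low bits (ql, two values per byte) and 2 high bits
+// (qh, four values per byte).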
+void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    int8_t L[QK_K];
+    float   scales[QK_K/16];
+
+    for (int i = 0; i < nb; i++) {
+
+        float max_scale = 0;
+        float max_abs_scale = 0;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+
+            const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
+            scales[ib] = scale;
+
+            const float abs_scale = fabsf(scale);
+            if (abs_scale > max_abs_scale) {
+                max_abs_scale = abs_scale;
+                max_scale = scale;
+            }
+
+        }
+
+        if (max_abs_scale < GROUP_MAX_EPS) {
+            memset(&y[i], 0, sizeof(block_q6_K));
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+            x += QK_K;
+            continue;
+        }
+
+        float iscale = -128.f/max_scale;
+        y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+        }
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-32, MIN(31, l));
+                L[16*j + ii] = l + 32;
+            }
+        }
+
+        uint8_t * restrict ql = y[i].ql;
+        uint8_t * restrict qh = y[i].qh;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                const uint8_t q1 = L[j + l +  0] & 0xF;
+                const uint8_t q2 = L[j + l + 32] & 0xF;
+                const uint8_t q3 = L[j + l + 64] & 0xF;
+                const uint8_t q4 = L[j + l + 96] & 0xF;
+                ql[l+ 0] = q1 | (q3 << 4);
+                ql[l+32] = q2 | (q4 << 4);
+                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+            }
+            ql += 64;
+            qh += 32;
+        }
+
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict ql = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict sc = x[i].scales;
+
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
+            }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+    }
+}
+
+static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
+    assert(n_per_row % QK_K == 0);
+    const int64_t nb = n_per_row / QK_K;
+
+    int8_t L[QK_K];
+    float   scales[QK_K/16];
+    //float   weights[16];
+
+    for (int i = 0; i < nb; i++) {
+
+        //float sum_x2 = 0;
+        //for (int j = 0; j < QK_K; ++j) sum_x2 += x[j]*x[j];
+        //float sigma2 = sum_x2/QK_K;
+
+        float max_scale = 0;
+        float max_abs_scale = 0;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+
+            float scale;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*i + 16*ib;
+                //for (int j = 0; j < 16; ++j) weights[j] = qw[j] * sqrtf(sigma2 + x[16*ib + j]*x[16*ib + j]);
+                //scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, weights);
+                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, qw);
+            } else {
+                scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1, NULL);
+            }
+            scales[ib] = scale;
+
+            const float abs_scale = fabsf(scale);
+            if (abs_scale > max_abs_scale) {
+                max_abs_scale = abs_scale;
+                max_scale = scale;
+            }
+
+        }
+
+        if (max_abs_scale < GROUP_MAX_EPS) {
+            memset(&y[i], 0, sizeof(block_q6_K));
+            y[i].d = GGML_FP32_TO_FP16(0.f);
+            x += QK_K;
+            continue;
+        }
+
+        float iscale = -128.f/max_scale;
+        y[i].d = GGML_FP32_TO_FP16(1/iscale);
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
+        }
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j];
+            if (!d) {
+                continue;
+            }
+            for (int ii = 0; ii < 16; ++ii) {
+                int l = nearest_int(x[16*j + ii]/d);
+                l = MAX(-32, MIN(31, l));
+                L[16*j + ii] = l + 32;
+            }
+        }
+
+        uint8_t * restrict ql = y[i].ql;
+        uint8_t * restrict qh = y[i].qh;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                const uint8_t q1 = L[j + l +  0] & 0xF;
+                const uint8_t q2 = L[j + l + 32] & 0xF;
+                const uint8_t q3 = L[j + l + 64] & 0xF;
+                const uint8_t q4 = L[j + l + 96] & 0xF;
+                ql[l+ 0] = q1 | (q3 << 4);
+                ql[l+32] = q2 | (q4 << 4);
+                qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
+            }
+            ql += 64;
+            qh += 32;
+        }
+
+        x += QK_K;
+
+    }
+}
+
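+// Row-wise entry point: without an importance matrix the reference quantizer
+// handles all rows in one call; otherwise each row is quantized separately so
+// the per-row importance weights can be applied block by block.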
+size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
+    if (!quant_weights) {
+        quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
+    }
+    else {
+        char * qrow = (char *)dst;
+        for (int64_t row = 0; row < nrow; ++row) {
+            quantize_row_q6_K_impl(src, (block_q6_K*)qrow, n_per_row, quant_weights);
+            src += n_per_row;
+            qrow += row_size;
+        }
+    }
+    return nrow * row_size;
+}
+
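+// Importance-weighted variants of the legacy 4/5-bit quantizers. Each value is
+// weighted by qw[j] * sqrt(sigma2 + x[j]^2), where sigma2 is the mean square of
+// the row, so the scale search puts more weight on entries with large
+// importance and magnitude.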
+static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+    static_assert(QK4_0 == 32, "QK4_0 must be 32");
+
+    if (!quant_weights) {
+        quantize_row_q4_0_ref(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK4_0];
+    int8_t L[QK4_0];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int64_t nb = n_per_row/QK4_0;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK4_0 * ib;
+        const float * qw = quant_weights + QK4_0 * ib;
+        for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
+        y[ib].d = GGML_FP32_TO_FP16(d);
+        for (int j = 0; j < 16; ++j) {
+            y[ib].qs[j] = L[j] | (L[j+16] << 4);
+        }
+    }
+}
+
+size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
+        return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
+    }
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_q4_0_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
+    static_assert(QK4_1 == 32, "QK4_1 must be 32");
+
+    if (!quant_weights) {
+        quantize_row_q4_1_ref(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK4_1];
+    uint8_t L[QK4_1], Laux[QK4_1];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int64_t nb = n_per_row/QK4_1;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK4_1 * ib;
+        const float * qw = quant_weights + QK4_1 * ib;
+        for (int j = 0; j < QK4_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float min;
+        float d = make_qkx3_quants(QK4_1, 15, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
+        y[ib].d = GGML_FP32_TO_FP16(d);
+        y[ib].m = GGML_FP32_TO_FP16(-min);
+        for (int j = 0; j < 16; ++j) {
+            y[ib].qs[j] = L[j] | (L[j+16] << 4);
+        }
+    }
+}
+
+size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
+        return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
+    }
+    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_q4_1_impl(src, (block_q4_1*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restrict y, int64_t n_per_row, const float * quant_weights) {
+    static_assert(QK5_0 == 32, "QK5_0 must be 32");
+
+    if (!quant_weights) {
+        quantize_row_q5_0_ref(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK5_0];
+    int8_t L[QK5_0];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int64_t nb = n_per_row/QK5_0;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK5_0 * ib;
+        const float * qw = quant_weights + QK5_0 * ib;
+        for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float d = make_qx_quants(QK5_0, 16, xb, L, 1, weight);
+        y[ib].d = GGML_FP32_TO_FP16(d);
+
+        uint32_t qh = 0;
+
+        for (int j = 0; j < 16; ++j) {
+            const uint8_t xi0 = L[j];
+            const uint8_t xi1 = L[j+16];
+            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+        }
+
+        memcpy(&y[ib].qh, &qh, sizeof(qh));
+    }
+}
+
+size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
+        return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
+    }
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_q5_0_impl(src, (block_q5_0*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
+static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restrict y, int64_t n_per_row, const float * quant_weights) {
+    static_assert(QK5_1 == 32, "QK5_1 must be 32");
+
+    if (!quant_weights) {
+        quantize_row_q5_1_ref(x, y, n_per_row);
+        return;
+    }
+
+    float weight[QK5_1];
+    uint8_t L[QK5_1], Laux[QK5_1];
+
+    float sum_x2 = 0;
+    for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
+    float sigma2 = sum_x2/n_per_row;
+
+    const int64_t nb = n_per_row/QK5_1;
+    for (int ib = 0; ib < nb; ++ib) {
+        const float * xb = x + QK5_1 * ib;
+        const float * qw = quant_weights + QK5_1 * ib;
+        for (int j = 0; j < QK5_1; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        float min;
+        float d = make_qkx3_quants(QK5_1, 31, xb, weight, L, &min, Laux, -0.9f, 0.05f, 36, false);
+        y[ib].d = GGML_FP32_TO_FP16(d);
+        y[ib].m = GGML_FP32_TO_FP16(-min);
+
+        uint32_t qh = 0;
+        for (int j = 0; j < 16; ++j) {
+            const uint8_t xi0 = L[j];
+            const uint8_t xi1 = L[j+16];
+            y[ib].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
+            // get the 5-th bit and store it in qh at the right position
+            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+        }
+        memcpy(&y[ib].qh, &qh, sizeof(qh));
+    }
+}
+
+size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    if (!quant_weights) {
+        quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
+        return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
+    }
+    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_q5_1_impl(src, (block_q5_1*)qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += row_size;
+    }
+    return nrow * row_size;
+}
+
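+// q8_0 has no importance-weighted path: quant_weights is ignored and every row
+// goes through the reference round-to-nearest quantizer.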
+size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
+    quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
+// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
+
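+// tq1_0: values are scaled by the block absmax, rounded to {-1, 0, +1}, and
+// packed in base 3 (five trits per byte in qs, four per byte in qh). Each
+// packed byte is rescaled as ceil(q * 256 / 243) so decoding needs only a
+// multiply and a shift.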
+void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK_K; j++) {
+            const float v = x[j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        // 5 elements per byte, along 32 bytes
+        for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) {
+            for (size_t m = 0; m < 32; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 5; ++n) {
+                    int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+                    q *= 3;
+                    q += xi;
+                }
+                // ceiling division (243 == pow(3, 5))
+                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+                y[i].qs[j + m] = q;
+            }
+            x += 5*32;
+        }
+        // along 16 bytes
+        for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) {
+            for (size_t m = 0; m < 16; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 5; ++n) {
+                    int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2
+                    q *= 3;
+                    q += xi;
+                }
+                // ceiling division (243 == pow(3, 5))
+                q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+                y[i].qs[j + m] = q;
+            }
+            x += 5*16;
+        }
+        // 4 elements per byte
+        for (size_t j = 0; j < sizeof(y->qh); ++j) {
+            uint8_t q = 0;
+            for (size_t m = 0; m < 4; ++m) {
+                // -1, 0, 1 -> 0, 1, 2
+                int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1;
+                q *= 3;
+                q += xi;
+            }
+            // shift the first value to the most significant trit
+            q *= 3;
+            // ceiling division (243 == pow(3, 5))
+            q = ((uint16_t)q * 256 + (243 - 1)) / 243;
+            y[i].qh[j] = q;
+        }
+        x += 4*sizeof(y->qh);
+    }
+}
+
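+// tq2_0: same ternary rounding as tq1_0, but each value simply takes 2 bits
+// (0, 1, 2 for -1, 0, +1), four values per byte, trading a little size for
+// simpler unpacking.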
+void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < QK_K; j++) {
+            const float v = x[j];
+            amax = MAX(amax, fabsf(v));
+        }
+
+        const float d = amax;
+        const float id = d ? 1.0f/d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (size_t j = 0; j < sizeof(y->qs); j += 32) {
+            for (size_t m = 0; m < 32; ++m) {
+                uint8_t q = 0;
+                for (size_t n = 0; n < 4; ++n) {
+                    // -1, 0, 1 -> 0, 1, 2
+                    int xi = lroundf(x[m + n*32] * id) + 1;
+                    q += (xi & 3) << (2*n);
+                }
+                y[i].qs[j + m] = q;
+            }
+            x += 4*32;
+        }
+    }
+}
+
+size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row);
+    quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
+size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    (void)quant_weights; // not used
+    const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row);
+    quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * row_size;
+}
+
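+// Decoding a trit from the base-3 packing: multiplying the byte by 3^n pushes
+// the n-th trit toward the top of the 8-bit range, and ((uint16_t)q * 3) >> 8
+// then recovers it as 0, 1 or 2 without any division.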
+void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
+
+    for (int64_t i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) {
+            for (size_t n = 0; n < 5; ++n) {
+                for (size_t m = 0; m < 32; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[n];
+                    int16_t xi = ((uint16_t) q * 3) >> 8;
+                    *y++ = (float) (xi - 1) * d;
+                }
+            }
+        }
+        for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) {
+            for (size_t n = 0; n < 5; ++n) {
+                for (size_t m = 0; m < 16; ++m) {
+                    uint8_t q = x[i].qs[j + m] * pow3[n];
+                    int16_t xi = ((uint16_t) q * 3) >> 8;
+                    *y++ = (float) (xi - 1) * d;
+                }
+            }
+        }
+
+        for (size_t n = 0; n < 4; ++n) {
+            for (size_t j = 0; j < sizeof(x->qh); ++j) {
+                uint8_t q = x[i].qh[j] * pow3[n];
+                int16_t xi = ((uint16_t) q * 3) >> 8;
+                *y++ = (float) (xi - 1) * d;
+            }
+        }
+    }
+}
+
+void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int64_t i = 0; i < nb; ++i) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
+            for (size_t l = 0; l < 4; ++l) {
+                for (size_t m = 0; m < 32; ++m) {
+                    int8_t q = (x[i].qs[j + m] >> (l*2)) & 3;
+                    *y++ = (float) (q - 1) * d;
+                }
+            }
+        }
+    }
+}
+
+// ====================== "True" 2-bit (de)-quantization
+
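+// iq2_xxs: every group of 8 weights is an entry of the 256-entry iq2xxs_grid
+// codebook with per-element sign flips; the second 32-bit word of each group
+// of 32 packs four 7-bit sign indices plus a 4-bit scale in the top nibble.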
+void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    uint32_t aux32[2];
+    const uint8_t * aux8 = (const uint8_t *)aux32;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(aux32, x[i].qs + 4*ib32, 2*sizeof(uint32_t));
+            const float db = d * (0.5f + (aux32[1] >> 28)) * 0.25f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
+                const uint8_t  signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127];
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = db * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+        }
+    }
+}
+
+// ====================== 2.3125 bpw (de)-quantization
+
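+// iq2_xs: like iq2_xxs but each group of 8 carries its own 16-bit field, a
+// 9-bit index into the 512-entry iq2xs_grid plus a 7-bit sign index, and
+// every 16 values get a separate 4-bit scale.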
+void dequantize_row_iq2_xs(const block_iq2_xs * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    float db[2];
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
+            db[1] = d * (0.5f + (x[i].scales[ib32] >>  4)) * 0.25f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (x[i].qs[4*ib32 + l] & 511));
+                const uint8_t  signs = ksigns_iq2xs[x[i].qs[4*ib32 + l] >> 9];
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = db[l/2] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+        }
+    }
+}
+
+// ====================== 2.5625 bpw (de)-quantization
+
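+// iq2_s: the grid index grows to 10 bits (8 bits from qs plus 2 high bits from
+// qh) into the 1024-entry iq2s_grid, and the sign bits are stored explicitly
+// in the bytes following qs instead of through the ksigns table.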
+void dequantize_row_iq2_s(const block_iq2_s * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    float db[2];
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+        const uint8_t * signs = qs + QK_K/8;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            db[0] = d * (0.5f + (x[i].scales[ib32] & 0xf)) * 0.25f;
+            db[1] = d * (0.5f + (x[i].scales[ib32] >>  4)) * 0.25f;
+            for (int l = 0; l < 4; ++l) {
+                const float dl = db[l/2];
+                const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300)));
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qs += 4;
+            signs += 4;
+        }
+    }
+}
+
+// ====================== 3.0625 bpw (de)-quantization
+
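+// iq3_xxs: groups of 4 weights come from the 256-entry iq3xxs_grid; scales and
+// sign indices for each 32-value block live in a separate uint32 after the
+// grid indices (4-bit scale in the top nibble, four 7-bit sign indices below).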
+void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    uint32_t aux32;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * scales_and_signs = qs + QK_K/4;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
+            memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t));
+            const float db = d * (0.5f + (aux32 >> 28)) * 0.5f;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t  signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
+                const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + qs[2*l+0]);
+                const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + qs[2*l+1]);
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qs += 8;
+        }
+    }
+}
+
+// ====================== 3.3125 bpw (de)-quantization
+
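+// iq3_s: 9-bit grid indices (8 bits from qs plus 1 high bit from qh) into
+// iq3s_grid, explicit sign bytes, and a 4-bit scale per 32 values applied as
+// d * (1 + 2*scale).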
+void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+        const uint8_t * signs = x[i].signs;
+
+        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+            const float db1 = d * (1 + 2*(x[i].scales[ib32/2] & 0xf));
+            const float db2 = d * (1 + 2*(x[i].scales[ib32/2] >>  4));
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qs += 8;
+            signs += 4;
+            for (int l = 0; l < 4; ++l) {
+                const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)));
+                const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)));
+                for (int j = 0; j < 4; ++j) {
+                    y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
+                    y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
+                }
+                y += 8;
+            }
+            qh += 2;
+            qs += 8;
+            signs += 4;
+        }
+    }
+}
+
+// ====================== 1.5625 bpw (de)-quantization
+
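+// iq1_s: each group of 8 is an entry of iq1s_grid addressed by 11 bits (8 from
+// qs, 3 from qh); the remaining qh bits give a 3-bit scale per 32 values and a
+// sign bit that shifts the whole 32-value block by +/- IQ1S_DELTA.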
+void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const uint8_t  * qs = x[i].qs;
+        const uint16_t * qh = x[i].qh;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl = d * (2*((qh[ib] >> 12) & 7) + 1);
+            const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA;
+            for (int l = 0; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl * (grid[j] + delta);
+                }
+                y += 8;
+            }
+            qs += 4;
+        }
+    }
+}
+
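+// iq1_m: the fp16 super-block scale is reassembled from the top 4 bits of each
+// of the four 16-bit scale words; 3-bit sub-scales apply per 16 values and the
+// +/- IQ1S_DELTA shift is chosen per group of 8 from individual qh bits.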
+void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    float delta[4];
+    uint16_t idx[4];
+
+    iq1m_scale_t scale;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint16_t * sc = (const uint16_t *)x[i].scales;
+        scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+        const float d = GGML_FP16_TO_FP32(scale.f16);
+
+        const uint8_t * qs = x[i].qs;
+        const uint8_t * qh = x[i].qh;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1);
+            const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1);
+
+            idx[0] = qs[0] | ((qh[0] << 8) & 0x700);
+            idx[1] = qs[1] | ((qh[0] << 4) & 0x700);
+            idx[2] = qs[2] | ((qh[1] << 8) & 0x700);
+            idx[3] = qs[3] | ((qh[1] << 4) & 0x700);
+            delta[0] = qh[0] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[1] = qh[0] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[2] = qh[1] & 0x08 ? -IQ1S_DELTA : IQ1S_DELTA;
+            delta[3] = qh[1] & 0x80 ? -IQ1S_DELTA : IQ1S_DELTA;
+            for (int l = 0; l < 2; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl1 * (grid[j] + delta[l]);
+                }
+                y += 8;
+            }
+            for (int l = 2; l < 4; ++l) {
+                const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
+                for (int j = 0; j < 8; ++j) {
+                    y[j] = dl2 * (grid[j] + delta[l]);
+                }
+                y += 8;
+            }
+            qs += 4;
+            qh += 2;
+        }
+    }
+}
+
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
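+// iq4_nl: a plain 4-bit format, except that the 16 quant levels are the fixed
+// non-uniform values in kvalues_iq4nl rather than a linear -8..7 ramp, with
+// one fp16 scale per block.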
+void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK4_NL == 0);
+    const int64_t nb = k / QK4_NL;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * qs = x[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        for (int j = 0; j < QK4_NL/2; ++j) {
+            y[j+       0] = d * kvalues_iq4nl[qs[j] & 0xf];
+            y[j+QK4_NL/2] = d * kvalues_iq4nl[qs[j] >>  4];
+        }
+        y  += QK4_NL;
+        qs += QK4_NL/2;
+    }
+}
+
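+// iq4_xs: same non-linear table as iq4_nl, arranged in 256-value super-blocks
+// with a 6-bit scale per 32 values (low 4 bits in scales_l, high 2 bits in
+// scales_h) applied as d * (ls - 32).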
+void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const uint8_t * qs = x[i].qs;
+
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const int ls = ((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4);
+            const float dl = d * (ls - 32);
+            for (int j = 0; j < 16; ++j) {
+                y[j+ 0] = dl * kvalues_iq4nl[qs[j] & 0xf];
+                y[j+16] = dl * kvalues_iq4nl[qs[j] >>  4];
+            }
+            y  += 32;
+            qs += 16;
+        }
+    }
+}
+
+//===================================== Q8_K ==============================================
+
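+// Q8_K: the 8-bit format used for the activation side of the K-quant dot
+// products. One float scale per 256-value block (d = 1/iscale), plus bsums
+// caching the sum of every group of 16 quants for the dot-product kernels.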
+void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        float max = 0;
+        float amax = 0;
+        for (int j = 0; j < QK_K; ++j) {
+            float ax = fabsf(x[j]);
+            if (ax > amax) {
+                amax = ax; max = x[j];
+            }
+        }
+        if (!amax) {
+            y[i].d = 0;
+            memset(y[i].qs, 0, QK_K);
+            x += QK_K;
+            continue;
+        }
+        //const float iscale = -128.f/max;
+        // We need this change for IQ2_XXS, else the AVX implementation becomes very awkward
+        const float iscale = -127.f/max;
+        for (int j = 0; j < QK_K; ++j) {
+            int v = nearest_int(iscale*x[j]);
+            y[i].qs[j] = MIN(127, v);
+        }
+        for (int j = 0; j < QK_K/16; ++j) {
+            int sum = 0;
+            for (int ii = 0; ii < 16; ++ii) {
+                sum += y[i].qs[j*16 + ii];
+            }
+            y[i].bsums[j] = sum;
+        }
+        y[i].d = 1/iscale;
+        x += QK_K;
+    }
+}
+
+void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+        for (int j = 0; j < QK_K; ++j) {
+            *y++ = x[i].d * x[i].qs[j];
+        }
+    }
+}
+
+// ================================ IQ2 quantization =============================================
+
+typedef struct {
+    uint64_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq2_entry_t;
+
+static iq2_entry_t iq2_data[4] = {
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+};
+
+static inline int iq2_data_index(enum ggml_type type) {
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
+    return type == GGML_TYPE_IQ2_XXS ? 0 :
+           type == GGML_TYPE_IQ2_XS  ? 1 :
+           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 2 : 3;
+}
+
+static inline int iq2_grid_size(enum ggml_type type) {
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
+    return type == GGML_TYPE_IQ2_XXS ? 256 :
+           type == GGML_TYPE_IQ2_XS  ? 512 :
+           type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? NGRID_IQ1S : 1024;
+}
+
+static int iq2_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
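+// Lazily builds the lookup tables needed to quantize to the IQ2/IQ1 families:
+// the expanded grid points, an index map over candidate codes, and
+// nearest-neighbour lists used when a block has no exact grid match. The
+// tables are built once per type and cached in iq2_data.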
+void iq2xs_init_impl(enum ggml_type type) {
+    const int gindex = iq2_data_index(type);
+    const int grid_size = iq2_grid_size(type);
+    if (iq2_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_2bit_256[256] = {
+            0,     2,     5,     8,    10,    17,    20,    32,    34,    40,    42,    65,    68,    80,    88,    97,
+          100,   128,   130,   138,   162,   257,   260,   272,   277,   320,   388,   408,   512,   514,   546,   642,
+         1025,  1028,  1040,  1057,  1060,  1088,  1090,  1096,  1120,  1153,  1156,  1168,  1188,  1280,  1282,  1288,
+         1312,  1350,  1385,  1408,  1425,  1545,  1552,  1600,  1668,  1700,  2048,  2053,  2056,  2068,  2088,  2113,
+         2116,  2128,  2130,  2184,  2308,  2368,  2562,  2580,  4097,  4100,  4112,  4129,  4160,  4192,  4228,  4240,
+         4245,  4352,  4360,  4384,  4432,  4442,  4480,  4644,  4677,  5120,  5128,  5152,  5157,  5193,  5248,  5400,
+         5474,  5632,  5654,  6145,  6148,  6160,  6208,  6273,  6400,  6405,  6560,  6737,  8192,  8194,  8202,  8260,
+         8289,  8320,  8322,  8489,  8520,  8704,  8706,  9217,  9220,  9232,  9280,  9302,  9472,  9537,  9572,  9872,
+        10248, 10272, 10388, 10820, 16385, 16388, 16400, 16408, 16417, 16420, 16448, 16456, 16470, 16480, 16513, 16516,
+        16528, 16640, 16672, 16737, 16768, 16773, 16897, 16912, 16968, 16982, 17000, 17408, 17416, 17440, 17536, 17561,
+        17682, 17700, 17920, 18433, 18436, 18448, 18496, 18501, 18688, 18776, 18785, 18818, 19013, 19088, 20480, 20488,
+        20497, 20505, 20512, 20608, 20616, 20740, 20802, 20900, 21137, 21648, 21650, 21770, 22017, 22100, 22528, 22545,
+        22553, 22628, 22848, 23048, 24580, 24592, 24640, 24680, 24832, 24917, 25112, 25184, 25600, 25605, 25872, 25874,
+        25988, 26690, 32768, 32770, 32778, 32833, 32898, 33028, 33048, 33088, 33297, 33793, 33796, 33808, 33813, 33856,
+        33888, 34048, 34118, 34196, 34313, 34368, 34400, 34818, 35076, 35345, 36868, 36880, 36900, 36928, 37025, 37142,
+        37248, 37445, 37888, 37922, 37956, 38225, 39041, 39200, 40962, 41040, 41093, 41225, 41472, 42008, 43088, 43268,
+    };
+    static const uint16_t kgrid_2bit_512[512] = {
+            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
+           73,    80,    82,    85,    88,    97,   100,   128,   130,   133,   136,   145,   148,   153,   160,   257,
+          260,   262,   265,   272,   274,   277,   280,   282,   289,   292,   320,   322,   325,   328,   337,   340,
+          352,   360,   385,   388,   400,   512,   514,   517,   520,   529,   532,   544,   577,   580,   592,   597,
+          640,   650,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1088,  1090,  1093,  1096,
+         1105,  1108,  1110,  1120,  1153,  1156,  1168,  1280,  1282,  1285,  1288,  1297,  1300,  1312,  1345,  1348,
+         1360,  1377,  1408,  1537,  1540,  1552,  1574,  1600,  1602,  1668,  2048,  2050,  2053,  2056,  2058,  2065,
+         2068,  2080,  2085,  2113,  2116,  2128,  2136,  2176,  2208,  2218,  2305,  2308,  2320,  2368,  2433,  2441,
+         2560,  2592,  2600,  2710,  2720,  4097,  4100,  4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4160,
+         4162,  4165,  4168,  4177,  4180,  4192,  4202,  4225,  4228,  4240,  4352,  4354,  4357,  4360,  4369,  4372,
+         4384,  4417,  4420,  4432,  4480,  4500,  4502,  4609,  4612,  4614,  4624,  4672,  4704,  5120,  5122,  5125,
+         5128,  5137,  5140,  5152,  5185,  5188,  5193,  5200,  5220,  5248,  5377,  5380,  5392,  5440,  5632,  5652,
+         5705,  6145,  6148,  6160,  6162,  6208,  6228,  6278,  6400,  6405,  6502,  6737,  6825,  8192,  8194,  8197,
+         8200,  8202,  8209,  8212,  8224,  8257,  8260,  8272,  8320,  8352,  8449,  8452,  8464,  8512,  8520,  8549,
+         8704,  8738,  8832,  8872,  9217,  9220,  9232,  9257,  9280,  9472,  9537,  9554,  9625,  9729,  9754,  9894,
+        10240, 10248, 10250, 10272, 10325, 10376, 10402, 10600, 10640, 10760, 10784, 10882, 10888, 10890, 16385, 16388,
+        16390, 16393, 16400, 16402, 16405, 16408, 16417, 16420, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16480,
+        16485, 16513, 16516, 16528, 16640, 16642, 16645, 16648, 16657, 16660, 16672, 16705, 16708, 16720, 16768, 16773,
+        16802, 16897, 16900, 16912, 16914, 16937, 16960, 17408, 17410, 17413, 17416, 17425, 17428, 17433, 17440, 17473,
+        17476, 17488, 17536, 17556, 17665, 17668, 17680, 17700, 17728, 17818, 17920, 17930, 17988, 18000, 18433, 18436,
+        18448, 18496, 18501, 18516, 18530, 18688, 18705, 18756, 18768, 18793, 18948, 20480, 20482, 20485, 20488, 20497,
+        20500, 20512, 20520, 20545, 20548, 20560, 20608, 20737, 20740, 20752, 20757, 20800, 20802, 20992, 21060, 21162,
+        21505, 21508, 21520, 21537, 21568, 21600, 21633, 21665, 21760, 21768, 21888, 21896, 22049, 22120, 22177, 22528,
+        22548, 22593, 22608, 22681, 22810, 22848, 22850, 23173, 24577, 24580, 24592, 24640, 24660, 24674, 24710, 24745,
+        24832, 25124, 25162, 25234, 25600, 25622, 25872, 25920, 25925, 26020, 26625, 26730, 26917, 27142, 27220, 27234,
+        32768, 32770, 32773, 32776, 32785, 32788, 32800, 32810, 32833, 32836, 32848, 32896, 32898, 32936, 32938, 33025,
+        33028, 33030, 33040, 33088, 33105, 33113, 33280, 33312, 33408, 33410, 33440, 33448, 33793, 33796, 33808, 33810,
+        33813, 33856, 33888, 33929, 34048, 34116, 34213, 34328, 34410, 34816, 34824, 34853, 34906, 34944, 34946, 34984,
+        35078, 35362, 35456, 35464, 35478, 35496, 36865, 36868, 36880, 36928, 36950, 36996, 37120, 37154, 37220, 37462,
+        37513, 37888, 37893, 37956, 37968, 37976, 38185, 38288, 38290, 38465, 38993, 39078, 39241, 39445, 39520, 40960,
+        40962, 40968, 40970, 40992, 41002, 41120, 41297, 41305, 41382, 41472, 41474, 41480, 41514, 41600, 41632, 42048,
+        42133, 42597, 42648, 43018, 43040, 43042, 43048, 43168, 43176, 43268, 43396, 43398, 43560, 43562, 43665, 43690,
+    };
+    static const uint16_t kgrid_1bit_2048[NGRID_IQ1S] = {
+            0,     2,     5,     8,    10,    17,    21,    32,    34,    40,    42,    69,    81,    84,    86,   101,
+          128,   130,   136,   138,   149,   160,   162,   168,   170,   260,   261,   273,   276,   278,   281,   282,
+          293,   321,   326,   329,   338,   341,   346,   353,   356,   358,   360,   389,   401,   404,   406,   421,
+          512,   514,   520,   522,   533,   544,   546,   552,   554,   581,   593,   601,   612,   617,   640,   642,
+          648,   650,   657,   661,   665,   672,   674,   680,   682,  1041,  1044,  1046,  1061,  1089,  1097,  1109,
+         1114,  1124,  1125,  1169,  1177,  1189,  1281,  1284,  1285,  1286,  1301,  1304,  1306,  1321,  1344,  1349,
+         1354,  1360,  1361,  1364,  1365,  1366,  1369,  1376,  1378,  1381,  1384,  1386,  1409,  1425,  1429,  1432,
+         1434,  1441,  1444,  1445,  1446,  1449,  1556,  1561,  1601,  1604,  1616,  1618,  1621,  1624,  1632,  1633,
+         1638,  1641,  1669,  1681,  1684,  1689,  2048,  2050,  2056,  2058,  2069,  2080,  2082,  2088,  2090,  2117,
+         2129,  2134,  2149,  2176,  2178,  2184,  2186,  2197,  2208,  2210,  2216,  2218,  2309,  2321,  2324,  2329,
+         2340,  2341,  2369,  2384,  2385,  2389,  2401,  2404,  2409,  2449,  2452,  2454,  2457,  2469,  2560,  2562,
+         2568,  2570,  2581,  2592,  2594,  2600,  2602,  2629,  2641,  2649,  2657,  2661,  2688,  2690,  2693,  2696,
+         2698,  2709,  2720,  2722,  2728,  2730,  4112,  4113,  4116,  4121,  4132,  4133,  4161,  4164,  4176,  4181,
+         4184,  4193,  4196,  4197,  4201,  4241,  4244,  4246,  4257,  4261,  4353,  4356,  4358,  4361,  4368,  4370,
+         4373,  4376,  4385,  4388,  4393,  4421,  4426,  4432,  4433,  4434,  4436,  4437,  4438,  4441,  4448,  4453,
+         4484,  4498,  4501,  4513,  4516,  4625,  4628,  4630,  4645,  4672,  4678,  4681,  4690,  4693,  4696,  4698,
+         4708,  4710,  4741,  4753,  4756,  4758,  4773,  5121,  5126,  5129,  5140,  5141,  5144,  5145,  5153,  5158,
+         5185,  5189,  5190,  5192,  5194,  5201,  5204,  5205,  5206,  5209,  5218,  5221,  5224,  5252,  5257,  5264,
+         5268,  5269,  5272,  5273,  5274,  5281,  5284,  5285,  5289,  5378,  5381,  5386,  5393,  5396,  5397,  5398,
+         5401,  5408,  5410,  5413,  5416,  5418,  5441,  5444,  5445,  5446,  5457,  5458,  5460,  5461,  5462,  5465,
+         5466,  5473,  5476,  5477,  5478,  5481,  5504,  5506,  5508,  5509,  5512,  5514,  5520,  5521,  5524,  5525,
+         5526,  5529,  5530,  5536,  5538,  5541,  5633,  5636,  5637,  5638,  5653,  5654,  5656,  5658,  5665,  5670,
+         5696,  5698,  5700,  5701,  5704,  5706,  5713,  5717,  5718,  5720,  5721,  5729,  5732,  5733,  5736,  5737,
+         5738,  5766,  5770,  5778,  5781,  5796,  5801,  6161,  6166,  6181,  6209,  6212,  6214,  6217,  6224,  6229,
+         6232,  6234,  6240,  6241,  6244,  6246,  6249,  6277,  6289,  6292,  6309,  6416,  6418,  6421,  6426,  6433,
+         6437,  6466,  6468,  6469,  6472,  6481,  6484,  6485,  6486,  6489,  6490,  6496,  6501,  6506,  6537,  6545,
+         6546,  6549,  6552,  6561,  6566,  6569,  6665,  6678,  6692,  6694,  6724,  6726,  6729,  6736,  6738,  6741,
+         6744,  6753,  6758,  6761,  6789,  6801,  6806,  6810,  8192,  8194,  8200,  8202,  8213,  8224,  8226,  8229,
+         8232,  8234,  8261,  8273,  8281,  8289,  8293,  8320,  8322,  8328,  8330,  8341,  8352,  8354,  8357,  8360,
+         8362,  8453,  8465,  8468,  8473,  8485,  8514,  8516,  8521,  8533,  8536,  8538,  8545,  8548,  8549,  8550,
+         8581,  8592,  8598,  8601,  8613,  8705,  8712,  8714,  8721,  8725,  8736,  8738,  8744,  8746,  8773,  8785,
+         8790,  8793,  8805,  8833,  8840,  8842,  8849,  8853,  8864,  8866,  8872,  8874,  9221,  9236,  9238,  9241,
+         9253,  9284,  9285,  9286,  9289,  9298,  9301,  9304,  9306,  9318,  9349,  9361,  9364,  9369,  9377,  9381,
+         9481,  9493,  9505,  9513,  9536,  9541,  9544,  9553,  9556,  9557,  9561,  9570,  9573,  9576,  9609,  9616,
+         9620,  9621,  9624,  9626,  9633,  9636,  9638,  9641,  9733,  9744,  9746,  9753,  9765,  9793,  9801,  9813,
+         9824,  9825,  9833,  9860,  9862,  9872,  9882, 10240, 10242, 10248, 10250, 10261, 10272, 10274, 10280, 10282,
+        10309, 10321, 10324, 10341, 10368, 10370, 10376, 10378, 10400, 10402, 10408, 10410, 10505, 10513, 10516, 10521,
+        10533, 10566, 10569, 10578, 10581, 10593, 10596, 10598, 10601, 10629, 10640, 10646, 10649, 10660, 10661, 10752,
+        10754, 10760, 10762, 10784, 10786, 10792, 10794, 10821, 10833, 10838, 10841, 10853, 10880, 10882, 10888, 10890,
+        10901, 10912, 10914, 10920, 10922, 16389, 16401, 16406, 16421, 16457, 16466, 16469, 16472, 16474, 16481, 16484,
+        16486, 16532, 16537, 16545, 16550, 16640, 16641, 16644, 16646, 16649, 16658, 16661, 16662, 16664, 16666, 16673,
+        16678, 16681, 16709, 16712, 16714, 16721, 16724, 16725, 16726, 16729, 16730, 16741, 16744, 16746, 16769, 16772,
+        16774, 16784, 16786, 16789, 16800, 16801, 16802, 16901, 16913, 16916, 16918, 16933, 16961, 16978, 16981, 16986,
+        16996, 17001, 17033, 17044, 17061, 17409, 17429, 17433, 17449, 17477, 17480, 17482, 17489, 17492, 17493, 17494,
+        17505, 17506, 17509, 17512, 17514, 17537, 17542, 17545, 17552, 17554, 17557, 17568, 17569, 17577, 17665, 17666,
+        17669, 17674, 17681, 17684, 17685, 17686, 17689, 17696, 17701, 17706, 17729, 17732, 17733, 17734, 17737, 17744,
+        17745, 17748, 17749, 17750, 17752, 17753, 17761, 17764, 17765, 17766, 17769, 17794, 17796, 17797, 17800, 17809,
+        17812, 17813, 17814, 17817, 17818, 17829, 17832, 17834, 17921, 17925, 17929, 17940, 17941, 17944, 17946, 17953,
+        17956, 17961, 17984, 17986, 17989, 17992, 18000, 18001, 18002, 18005, 18006, 18009, 18018, 18021, 18024, 18049,
+        18053, 18058, 18068, 18069, 18081, 18084, 18086, 18437, 18449, 18453, 18458, 18469, 18498, 18505, 18512, 18517,
+        18520, 18529, 18532, 18534, 18537, 18565, 18577, 18580, 18582, 18585, 18597, 18689, 18693, 18694, 18698, 18704,
+        18708, 18709, 18712, 18721, 18724, 18726, 18752, 18757, 18762, 18769, 18770, 18772, 18773, 18774, 18777, 18784,
+        18786, 18789, 18790, 18794, 18822, 18825, 18834, 18837, 18838, 18840, 18849, 18852, 18854, 18857, 18966, 19012,
+        19014, 19017, 19029, 19032, 19034, 19044, 19049, 19092, 19109, 20481, 20484, 20485, 20486, 20489, 20498, 20501,
+        20506, 20513, 20516, 20521, 20544, 20549, 20552, 20561, 20564, 20565, 20566, 20569, 20581, 20584, 20614, 20617,
+        20629, 20632, 20640, 20641, 20646, 20649, 20741, 20744, 20745, 20746, 20753, 20756, 20757, 20758, 20760, 20761,
+        20768, 20773, 20774, 20776, 20778, 20801, 20804, 20805, 20806, 20809, 20816, 20817, 20818, 20820, 20821, 20822,
+        20824, 20825, 20826, 20833, 20836, 20837, 20838, 20841, 20866, 20869, 20881, 20884, 20885, 20886, 20889, 20896,
+        20901, 20906, 20993, 20998, 21010, 21013, 21018, 21025, 21028, 21058, 21061, 21066, 21073, 21076, 21077, 21078,
+        21081, 21090, 21093, 21125, 21136, 21138, 21141, 21145, 21146, 21156, 21508, 21509, 21521, 21524, 21525, 21526,
+        21528, 21529, 21537, 21541, 21544, 21546, 21569, 21572, 21573, 21574, 21577, 21578, 21584, 21585, 21588, 21589,
+        21590, 21592, 21593, 21594, 21601, 21602, 21604, 21605, 21606, 21609, 21632, 21640, 21642, 21649, 21652, 21653,
+        21654, 21657, 21665, 21668, 21669, 21674, 21761, 21762, 21764, 21765, 21766, 21769, 21776, 21777, 21778, 21780,
+        21781, 21782, 21785, 21786, 21793, 21796, 21797, 21798, 21801, 21824, 21825, 21826, 21828, 21829, 21830, 21832,
+        21833, 21840, 21841, 21842, 21844, 21845, 21846, 21848, 21849, 21850, 21856, 21857, 21860, 21861, 21862, 21864,
+        21865, 21866, 21889, 21892, 21893, 21897, 21898, 21904, 21905, 21908, 21909, 21910, 21912, 21913, 21921, 21924,
+        21925, 21926, 21929, 22016, 22017, 22018, 22020, 22022, 22024, 22025, 22033, 22036, 22037, 22040, 22041, 22048,
+        22049, 22050, 22052, 22053, 22054, 22056, 22057, 22081, 22085, 22086, 22088, 22089, 22090, 22096, 22097, 22098,
+        22100, 22101, 22102, 22104, 22105, 22106, 22113, 22116, 22117, 22121, 22146, 22149, 22150, 22152, 22153, 22154,
+        22161, 22165, 22170, 22178, 22181, 22182, 22184, 22185, 22532, 22533, 22534, 22537, 22544, 22549, 22552, 22561,
+        22570, 22597, 22600, 22602, 22609, 22612, 22613, 22614, 22616, 22617, 22624, 22626, 22628, 22629, 22658, 22665,
+        22672, 22674, 22677, 22680, 22689, 22697, 22785, 22786, 22789, 22794, 22801, 22804, 22805, 22806, 22809, 22821,
+        22849, 22852, 22853, 22854, 22857, 22864, 22865, 22866, 22868, 22869, 22870, 22872, 22873, 22874, 22881, 22884,
+        22885, 22886, 22889, 22913, 22917, 22921, 22929, 22932, 22933, 22934, 22936, 22937, 22949, 23044, 23048, 23061,
+        23066, 23072, 23077, 23078, 23081, 23109, 23112, 23113, 23121, 23125, 23126, 23128, 23129, 23138, 23141, 23144,
+        23146, 23169, 23178, 23186, 23189, 23190, 23192, 23194, 23201, 24581, 24596, 24598, 24601, 24613, 24644, 24656,
+        24661, 24662, 24664, 24666, 24673, 24676, 24678, 24681, 24705, 24726, 24741, 24833, 24836, 24838, 24841, 24850,
+        24853, 24865, 24866, 24870, 24873, 24901, 24905, 24913, 24917, 24918, 24921, 24933, 24934, 24938, 24964, 24970,
+        24978, 24981, 24993, 24998, 25001, 25105, 25110, 25113, 25152, 25153, 25158, 25173, 25174, 25176, 25184, 25221,
+        25233, 25238, 25253, 25617, 25618, 25621, 25622, 25626, 25633, 25638, 25641, 25664, 25666, 25669, 25672, 25674,
+        25681, 25684, 25685, 25686, 25689, 25690, 25696, 25698, 25701, 25732, 25733, 25737, 25744, 25746, 25748, 25749,
+        25750, 25752, 25754, 25761, 25764, 25769, 25861, 25864, 25866, 25873, 25877, 25878, 25881, 25924, 25925, 25926,
+        25929, 25936, 25937, 25940, 25941, 25942, 25945, 25953, 25956, 25957, 25958, 25961, 25990, 25993, 25994, 26001,
+        26005, 26006, 26009, 26010, 26018, 26021, 26022, 26024, 26114, 26121, 26133, 26144, 26150, 26152, 26153, 26176,
+        26181, 26184, 26186, 26193, 26196, 26197, 26198, 26200, 26202, 26208, 26213, 26216, 26240, 26242, 26245, 26250,
+        26260, 26262, 26264, 26265, 26272, 26276, 26278, 26282, 26646, 26649, 26661, 26689, 26706, 26709, 26714, 26721,
+        26729, 26757, 26769, 26776, 26790, 26881, 26884, 26896, 26901, 26913, 26916, 26918, 26921, 26944, 26945, 26949,
+        26950, 26952, 26961, 26964, 26965, 26966, 26969, 26976, 26981, 26986, 27010, 27012, 27018, 27029, 27041, 27044,
+        27045, 27049, 27153, 27158, 27160, 27201, 27204, 27209, 27216, 27221, 27224, 27226, 27236, 27237, 27241, 27270,
+        27284, 27288, 27290, 27302, 32768, 32770, 32776, 32778, 32800, 32802, 32808, 32810, 32837, 32848, 32849, 32852,
+        32854, 32857, 32869, 32896, 32898, 32904, 32906, 32917, 32928, 32930, 32936, 32938, 33029, 33041, 33044, 33046,
+        33049, 33061, 33089, 33092, 33097, 33104, 33106, 33109, 33110, 33112, 33113, 33124, 33126, 33129, 33157, 33161,
+        33172, 33174, 33177, 33189, 33280, 33282, 33288, 33290, 33301, 33312, 33314, 33320, 33322, 33361, 33364, 33369,
+        33381, 33408, 33410, 33416, 33418, 33429, 33440, 33442, 33448, 33450, 33812, 33817, 33857, 33860, 33873, 33877,
+        33882, 33889, 33892, 33897, 33940, 33945, 34049, 34057, 34066, 34069, 34074, 34086, 34089, 34112, 34113, 34117,
+        34120, 34129, 34132, 34133, 34134, 34137, 34138, 34149, 34150, 34152, 34154, 34177, 34180, 34182, 34185, 34192,
+        34194, 34197, 34200, 34214, 34321, 34326, 34329, 34341, 34369, 34372, 34377, 34378, 34384, 34389, 34393, 34394,
+        34401, 34406, 34410, 34437, 34449, 34458, 34468, 34816, 34818, 34824, 34826, 34837, 34848, 34850, 34856, 34858,
+        34881, 34885, 34897, 34900, 34905, 34917, 34921, 34944, 34946, 34952, 34954, 34965, 34976, 34978, 34984, 34986,
+        35077, 35078, 35089, 35092, 35094, 35109, 35137, 35140, 35142, 35145, 35152, 35154, 35157, 35162, 35169, 35172,
+        35205, 35222, 35225, 35237, 35328, 35330, 35336, 35338, 35349, 35360, 35362, 35368, 35370, 35397, 35409, 35412,
+        35414, 35456, 35458, 35464, 35466, 35477, 35488, 35490, 35496, 35498, 36869, 36881, 36886, 36888, 36889, 36901,
+        36929, 36934, 36937, 36949, 36952, 36954, 36969, 36970, 36997, 37009, 37012, 37014, 37017, 37029, 37121, 37124,
+        37126, 37129, 37136, 37141, 37144, 37146, 37153, 37156, 37158, 37161, 37184, 37189, 37200, 37201, 37204, 37205,
+        37206, 37209, 37218, 37221, 37252, 37254, 37266, 37269, 37272, 37281, 37284, 37286, 37289, 37381, 37393, 37396,
+        37401, 37413, 37444, 37446, 37449, 37456, 37458, 37461, 37464, 37478, 37481, 37509, 37524, 37526, 37545, 37889,
+        37892, 37894, 37904, 37909, 37912, 37926, 37952, 37962, 37969, 37972, 37973, 37974, 37976, 37977, 37984, 37985,
+        37986, 37989, 38020, 38022, 38034, 38036, 38037, 38040, 38049, 38057, 38144, 38149, 38152, 38154, 38160, 38161,
+        38164, 38165, 38166, 38169, 38177, 38181, 38185, 38186, 38209, 38212, 38213, 38214, 38217, 38224, 38225, 38226,
+        38228, 38229, 38230, 38232, 38233, 38234, 38241, 38244, 38245, 38246, 38249, 38273, 38277, 38280, 38289, 38290,
+        38292, 38293, 38294, 38297, 38298, 38304, 38306, 38309, 38312, 38314, 38401, 38404, 38416, 38421, 38425, 38432,
+        38438, 38441, 38469, 38472, 38473, 38481, 38482, 38485, 38486, 38489, 38501, 38504, 38530, 38532, 38537, 38538,
+        38546, 38548, 38549, 38564, 38566, 38569, 38917, 38934, 38937, 38949, 38977, 38982, 38992, 38994, 38997, 38998,
+        39002, 39012, 39013, 39045, 39057, 39062, 39065, 39077, 39172, 39174, 39177, 39184, 39186, 39189, 39192, 39194,
+        39200, 39201, 39204, 39206, 39232, 39234, 39237, 39240, 39242, 39249, 39252, 39253, 39254, 39257, 39266, 39269,
+        39270, 39274, 39297, 39300, 39312, 39314, 39317, 39322, 39329, 39334, 39429, 39445, 39461, 39492, 39494, 39497,
+        39504, 39509, 39512, 39521, 39557, 39569, 39572, 39573, 39574, 40960, 40962, 40968, 40970, 40981, 40992, 40994,
+        41000, 41002, 41029, 41041, 41044, 41046, 41049, 41088, 41090, 41096, 41098, 41109, 41120, 41122, 41128, 41130,
+        41221, 41225, 41233, 41236, 41238, 41241, 41242, 41286, 41289, 41297, 41301, 41304, 41306, 41313, 41316, 41349,
+        41360, 41362, 41366, 41369, 41474, 41480, 41482, 41488, 41497, 41506, 41512, 41514, 41541, 41553, 41558, 41561,
+        41573, 41600, 41602, 41608, 41610, 41621, 41632, 41634, 41640, 41642, 42009, 42021, 42049, 42052, 42064, 42068,
+        42069, 42072, 42074, 42081, 42085, 42086, 42088, 42089, 42117, 42246, 42249, 42256, 42258, 42261, 42264, 42278,
+        42281, 42306, 42309, 42321, 42324, 42325, 42326, 42329, 42341, 42346, 42369, 42372, 42373, 42374, 42377, 42386,
+        42389, 42392, 42501, 42513, 42518, 42522, 42529, 42533, 42564, 42566, 42570, 42578, 42581, 42582, 42584, 42592,
+        42594, 42630, 42640, 42645, 42646, 42649, 42657, 42660, 42662, 43008, 43010, 43016, 43018, 43040, 43042, 43048,
+        43050, 43089, 43092, 43094, 43097, 43136, 43138, 43144, 43146, 43157, 43168, 43170, 43176, 43178, 43269, 43284,
+        43289, 43297, 43301, 43329, 43344, 43349, 43354, 43361, 43366, 43369, 43408, 43414, 43520, 43522, 43528, 43530,
+        43552, 43554, 43560, 43562, 43601, 43604, 43606, 43648, 43650, 43656, 43658, 43669, 43680, 43682, 43688, 43690,
+    };
+    static const uint16_t kgrid_2bit_1024[1024] = {
+            0,     2,     5,     8,    10,    17,    20,    22,    25,    32,    34,    37,    40,    65,    68,    70,
+           73,    80,    82,    85,    88,    97,   100,   102,   105,   128,   130,   133,   136,   145,   148,   160,
+          165,   170,   257,   260,   262,   265,   272,   274,   277,   280,   289,   292,   320,   322,   325,   328,
+          337,   340,   342,   345,   352,   357,   360,   385,   388,   400,   402,   405,   417,   420,   512,   514,
+          517,   520,   529,   532,   544,   554,   577,   580,   582,   585,   592,   597,   640,   645,   650,   660,
+          674,  1025,  1028,  1030,  1033,  1040,  1042,  1045,  1048,  1057,  1060,  1062,  1065,  1088,  1090,  1093,
+         1096,  1098,  1105,  1108,  1110,  1113,  1120,  1122,  1125,  1153,  1156,  1158,  1161,  1168,  1173,  1176,
+         1185,  1188,  1280,  1282,  1285,  1288,  1290,  1297,  1300,  1302,  1305,  1312,  1317,  1320,  1345,  1348,
+         1350,  1353,  1360,  1362,  1365,  1368,  1377,  1380,  1408,  1410,  1413,  1416,  1425,  1428,  1440,  1537,
+         1540,  1542,  1545,  1552,  1557,  1600,  1605,  1608,  1617,  1620,  1632,  1665,  1668,  1680,  2048,  2050,
+         2053,  2056,  2065,  2068,  2070,  2073,  2080,  2085,  2090,  2113,  2116,  2118,  2121,  2128,  2130,  2133,
+         2136,  2145,  2148,  2176,  2181,  2196,  2218,  2305,  2308,  2320,  2322,  2325,  2328,  2337,  2368,  2373,
+         2376,  2385,  2388,  2400,  2433,  2448,  2560,  2577,  2580,  2594,  2600,  2602,  2640,  2713,  4097,  4100,
+         4102,  4105,  4112,  4114,  4117,  4120,  4129,  4132,  4134,  4160,  4162,  4165,  4168,  4177,  4180,  4182,
+         4185,  4192,  4194,  4197,  4200,  4225,  4228,  4230,  4240,  4245,  4248,  4257,  4260,  4352,  4354,  4357,
+         4360,  4362,  4369,  4372,  4374,  4377,  4384,  4386,  4389,  4392,  4417,  4420,  4422,  4425,  4432,  4434,
+         4437,  4440,  4449,  4452,  4480,  4482,  4485,  4488,  4497,  4500,  4609,  4612,  4617,  4624,  4629,  4641,
+         4644,  4672,  4677,  4689,  4692,  4737,  4740,  4752,  5120,  5122,  5125,  5128,  5137,  5140,  5142,  5145,
+         5152,  5157,  5160,  5185,  5188,  5190,  5193,  5200,  5202,  5205,  5208,  5217,  5220,  5248,  5250,  5253,
+         5256,  5265,  5268,  5280,  5377,  5380,  5382,  5385,  5392,  5394,  5397,  5400,  5409,  5412,  5440,  5442,
+         5445,  5448,  5457,  5460,  5472,  5505,  5508,  5520,  5632,  5637,  5640,  5649,  5652,  5664,  5697,  5700,
+         5712,  5760,  5802,  6145,  6148,  6150,  6153,  6160,  6165,  6168,  6177,  6208,  6210,  6213,  6216,  6225,
+         6228,  6240,  6273,  6276,  6400,  6402,  6405,  6408,  6417,  6420,  6432,  6465,  6468,  6480,  6505,  6562,
+         6660,  6672,  6720,  6742,  8192,  8194,  8197,  8200,  8209,  8212,  8214,  8217,  8224,  8229,  8234,  8257,
+         8260,  8272,  8274,  8277,  8292,  8320,  8330,  8340,  8362,  8449,  8452,  8464,  8466,  8469,  8481,  8512,
+         8514,  8517,  8529,  8532,  8544,  8577,  8580,  8592,  8704,  8714,  8738,  8744,  8746,  8772,  8784,  8840,
+         8842,  8872,  9217,  9220,  9222,  9225,  9232,  9237,  9240,  9249,  9252,  9280,  9282,  9285,  9288,  9297,
+         9300,  9312,  9345,  9348,  9360,  9472,  9477,  9480,  9489,  9492,  9504,  9537,  9540,  9552,  9574,  9600,
+         9729,  9732,  9744,  9792,  9817, 10240, 10245, 10257, 10260, 10305, 10308, 10320, 10378, 10410, 10497, 10500,
+        10512, 10645, 10762, 10786, 10852, 10888, 10890, 16385, 16388, 16390, 16393, 16400, 16402, 16405, 16408, 16410,
+        16417, 16420, 16422, 16448, 16450, 16453, 16456, 16458, 16465, 16468, 16470, 16473, 16480, 16482, 16485, 16513,
+        16516, 16528, 16533, 16536, 16545, 16548, 16640, 16642, 16645, 16648, 16657, 16660, 16662, 16665, 16672, 16674,
+        16677, 16705, 16708, 16710, 16713, 16720, 16722, 16725, 16728, 16737, 16740, 16768, 16770, 16773, 16776, 16785,
+        16788, 16800, 16897, 16900, 16912, 16914, 16917, 16920, 16932, 16960, 16965, 16968, 16977, 16980, 16992, 17025,
+        17028, 17408, 17410, 17413, 17416, 17418, 17425, 17428, 17430, 17433, 17440, 17442, 17445, 17448, 17473, 17476,
+        17478, 17481, 17488, 17490, 17493, 17496, 17505, 17508, 17536, 17538, 17541, 17544, 17553, 17556, 17568, 17665,
+        17668, 17670, 17673, 17680, 17682, 17685, 17688, 17697, 17700, 17728, 17730, 17733, 17736, 17745, 17748, 17760,
+        17770, 17793, 17796, 17808, 17920, 17922, 17925, 17928, 17937, 17940, 17952, 17985, 17988, 18000, 18048, 18085,
+        18433, 18436, 18441, 18448, 18450, 18453, 18456, 18465, 18468, 18496, 18498, 18501, 18504, 18513, 18516, 18528,
+        18564, 18576, 18688, 18690, 18693, 18696, 18705, 18708, 18720, 18753, 18756, 18768, 18816, 18838, 18945, 18948,
+        18960, 19008, 20480, 20482, 20485, 20488, 20497, 20500, 20502, 20505, 20512, 20514, 20517, 20520, 20545, 20548,
+        20550, 20553, 20560, 20562, 20565, 20568, 20577, 20580, 20608, 20610, 20613, 20616, 20625, 20628, 20737, 20740,
+        20742, 20745, 20752, 20754, 20757, 20760, 20769, 20772, 20800, 20802, 20805, 20808, 20817, 20820, 20832, 20865,
+        20868, 20880, 20992, 20997, 21000, 21009, 21012, 21024, 21057, 21060, 21072, 21097, 21120, 21505, 21508, 21510,
+        21513, 21520, 21522, 21525, 21528, 21537, 21540, 21568, 21570, 21573, 21576, 21585, 21588, 21600, 21633, 21636,
+        21648, 21760, 21762, 21765, 21768, 21777, 21780, 21792, 21825, 21828, 21840, 21888, 22017, 22020, 22032, 22054,
+        22080, 22528, 22530, 22533, 22536, 22545, 22548, 22560, 22593, 22596, 22608, 22618, 22656, 22785, 22788, 22800,
+        22848, 23040, 23065, 23173, 23208, 24577, 24580, 24582, 24592, 24594, 24597, 24600, 24609, 24612, 24640, 24645,
+        24648, 24657, 24660, 24672, 24708, 24720, 24832, 24834, 24837, 24840, 24849, 24852, 24864, 24897, 24900, 24912,
+        24960, 24985, 25092, 25104, 25152, 25174, 25249, 25600, 25605, 25608, 25617, 25620, 25632, 25665, 25668, 25680,
+        25728, 25857, 25860, 25872, 25920, 25930, 25960, 26002, 26112, 26260, 26625, 26628, 26640, 26725, 26776, 26880,
+        26922, 27202, 27297, 32768, 32770, 32773, 32776, 32785, 32788, 32793, 32800, 32805, 32833, 32836, 32848, 32850,
+        32853, 32856, 32865, 32896, 32901, 32913, 32916, 33025, 33028, 33033, 33040, 33042, 33045, 33048, 33057, 33060,
+        33088, 33090, 33093, 33096, 33105, 33108, 33153, 33156, 33168, 33193, 33280, 33285, 33290, 33297, 33300, 33345,
+        33348, 33360, 33793, 33796, 33798, 33801, 33808, 33810, 33813, 33816, 33825, 33856, 33858, 33861, 33864, 33873,
+        33876, 33888, 33921, 33924, 33936, 34048, 34050, 34053, 34056, 34065, 34068, 34080, 34113, 34116, 34128, 34176,
+        34186, 34305, 34308, 34320, 34345, 34368, 34816, 34821, 34833, 34836, 34881, 34884, 34896, 34978, 35073, 35076,
+        35136, 35173, 35362, 35416, 35418, 35458, 35490, 36865, 36868, 36873, 36880, 36882, 36885, 36888, 36900, 36928,
+        36930, 36933, 36936, 36945, 36948, 36960, 36993, 36996, 37008, 37120, 37125, 37137, 37140, 37185, 37188, 37200,
+        37210, 37377, 37380, 37392, 37440, 37542, 37888, 37890, 37893, 37896, 37905, 37908, 37920, 37953, 37956, 37968,
+        38016, 38038, 38145, 38148, 38160, 38208, 38296, 38305, 38400, 38470, 38500, 38913, 38916, 38928, 38950, 38976,
+        39081, 39168, 39241, 39250, 39568, 40960, 40965, 40970, 40980, 40994, 41002, 41025, 41028, 41040, 41122, 41130,
+        41280, 41317, 41474, 41482, 41506, 41512, 41514, 41602, 41608, 41610, 41640, 41985, 41988, 42000, 42048, 42121,
+        42148, 42240, 42265, 42577, 43018, 43048, 43170, 43348, 43398, 43528, 43530, 43552, 43554, 43560, 43656, 43690,
+    };
+
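+    // Build a packed-index -> grid-index lookup map and, for packed indices that do not
+    // correspond to a grid point, a shared table of their nearest grid neighbours.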
+    const int kmap_size = 43692;
+    //const int nwant = type == GGML_TYPE_IQ1_S ? 3 : 2;
+    const int nwant = type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? 3 : type == GGML_TYPE_IQ2_S ? 1 : 2;
+    const uint16_t * kgrid = type == GGML_TYPE_IQ2_XXS ? kgrid_2bit_256 :
+                             type == GGML_TYPE_IQ2_XS  ? kgrid_2bit_512 :
+                             type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M ? kgrid_1bit_2048 : kgrid_2bit_1024;
+    uint64_t * kgrid_q2xs;
+    int      * kmap_q2xs;
+    uint16_t * kneighbors_q2xs;
+
+    //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint64_t * the_grid = (uint64_t *)malloc(grid_size*sizeof(uint64_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 8; ++i) {
+            int l = (kgrid[k] >> 2*i) & 0x3;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q2xs = the_grid;
+    iq2_data[gindex].grid = the_grid;
+    kmap_q2xs = (int *)malloc(kmap_size*sizeof(int));
+    iq2_data[gindex].map = kmap_q2xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q2xs[i] = -1;
+    uint64_t aux64;
+    uint8_t * aux8 = (uint8_t *)&aux64;
+    for (int i = 0; i < grid_size; ++i) {
+        aux64 = kgrid_q2xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<8; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 2*k);
+        }
+        kmap_q2xs[index] = i;
+    }
+    int8_t pos[8];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    //printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq2_data[gindex].neighbours = kneighbors_q2xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q2xs[i] >= 0) continue;
+        for (int k = 0; k < 8; ++k) {
+            int l = (i >> 2*k) & 0x3;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q2xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func);
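+        // off-grid index: store -(offset+1); the neighbour table entry at that offset holds
+        // the neighbour count, followed by that many grid indices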
+        kmap_q2xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q2xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q2xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
+void iq2xs_free_impl(enum ggml_type type) {
+    GGML_ASSERT(type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ1_S || type == GGML_TYPE_IQ1_M || type == GGML_TYPE_IQ2_S);
+    const int gindex = iq2_data_index(type);
+    if (iq2_data[gindex].grid) {
+        free(iq2_data[gindex].grid);       iq2_data[gindex].grid = NULL;
+        free(iq2_data[gindex].map);        iq2_data[gindex].map  = NULL;
+        free(iq2_data[gindex].neighbours); iq2_data[gindex].neighbours = NULL;
+    }
+}
+
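+// neighbours[0] is the number of candidate grid indices that follow; return the candidate
+// that minimizes the weighted squared error to xval at the given scale and write its
+// quant levels to L.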
+static int iq2_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
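+// For each super-block of QK_K values: split into groups of 32, fold out signs per group
+// of 8 (keeping an even number of flips so only 7 sign bits per group need to be stored),
+// then search a small range of candidate scales, snapping each group of 8 onto the nearest
+// codebook entry.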
+static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XXS);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int64_t nbl = n/QK_K;
+
+    block_iq2_xxs * y = vy;
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    uint8_t block_signs[4];
+    uint32_t q2[2*(QK_K/32)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            const float * qw = quant_weights + QK_K*ibl + 32*ib;
+            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (max < GROUP_MAX_EPS) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
+            float eff_max = scale*kMaxQ;
+            float best = 0;
+            for (int is = -6; is <= 6; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/eff_max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 4; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    memcpy(L, Laux, 32);
+                }
+            }
+            if (scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 4; ++k) {
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + grid_index);
+                    for (int i = 0; i < 8; ++i) L[8*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 4; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ABORT("fatal error");
+                }
+                q2[2*ib+0] |= ((uint32_t) grid_index << 8*k);
+                q2[2*ib+1] |= (block_signs[k] << 7*k);
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            q2[2*ib+1] |= ((uint32_t)l << 28);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+    }
+}
+
+static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ2_XS);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int64_t nbl = n/QK_K;
+
+    block_iq2_xs * y = vy;
+
+    float scales[QK_K/16];
+    float weight[16];
+    float xval[16];
+    int8_t L[16];
+    int8_t Laux[16];
+    float  waux[16];
+    bool   is_on_grid[2];
+    bool   is_on_grid_aux[2];
+    uint8_t block_signs[2];
+    uint16_t q2[2*(QK_K/16)];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(q2, 0, QK_K/4);
+        memset(y[ibl].scales, 0, QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            const float * xb = xbl + 16*ib;
+            const float * qw = quant_weights + QK_K*ibl + 16*ib;
+            for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 2; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            if (max < GROUP_MAX_EPS) {
+                scales[ib] = 0;
+                memset(L, 0, 16);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            is_on_grid[0] = is_on_grid[1] = true;
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 2; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 2; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                        L[8*k + i] = l;
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                scale = -scale;
+                for (int k = 0; k < 2; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 2; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ABORT("fatal error");
+                }
+                q2[2*ib+k] = grid_index | (block_signs[k] << 9);
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(y[ibl].qs, 0, QK_K/4);
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+            else y[ibl].scales[ib/2] |= (l << 4);
+        }
+        memcpy(y[ibl].qs, q2, QK_K/4);
+
+    }
+}
+
+size_t quantize_iq2_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq2_xxs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_xxs);
+    }
+    return nrow * nblock * sizeof(block_iq2_xxs);
+}
+
+size_t quantize_iq2_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq2_xs_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_xs);
+    }
+    return nrow * nblock * sizeof(block_iq2_xs);
+}
+
+//
+// ============================================= 3-bit using D4 lattice
+//
+
+typedef struct {
+    uint32_t * grid;
+    int      * map;
+    uint16_t * neighbours;
+} iq3_entry_t;
+
+static iq3_entry_t iq3_data[2] = {
+    {NULL, NULL, NULL},
+    {NULL, NULL, NULL},
+};
+
+static inline int iq3_data_index(int grid_size) {
+    (void)grid_size;
+    GGML_ASSERT(grid_size == 256 || grid_size == 512);
+    return grid_size == 256 ? 0 : 1;
+}
+
+static int iq3_compare_func(const void * left, const void * right) {
+    const int * l = (const int *)left;
+    const int * r = (const int *)right;
+    return l[0] < r[0] ? -1 : l[0] > r[0] ? 1 : l[1] < r[1] ? -1 : l[1] > r[1] ? 1 : 0;
+}
+
+void iq3xs_init_impl(int grid_size) {
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        return;
+    }
+    static const uint16_t kgrid_256[256] = {
+            0,     2,     4,     9,    11,    15,    16,    18,    25,    34,    59,    61,    65,    67,    72,    74,
+           81,    85,    88,    90,    97,   108,   120,   128,   130,   132,   137,   144,   146,   153,   155,   159,
+          169,   175,   189,   193,   199,   200,   202,   213,   248,   267,   287,   292,   303,   315,   317,   321,
+          327,   346,   362,   413,   436,   456,   460,   462,   483,   497,   513,   515,   520,   522,   529,   531,
+          536,   538,   540,   551,   552,   576,   578,   585,   592,   594,   641,   643,   648,   650,   657,   664,
+          698,   704,   706,   720,   729,   742,   758,   769,   773,   808,   848,   852,   870,   889,   901,   978,
+          992,  1024,  1026,  1033,  1035,  1040,  1042,  1046,  1049,  1058,  1089,  1091,  1093,  1096,  1098,  1105,
+         1112,  1139,  1143,  1144,  1152,  1154,  1161,  1167,  1168,  1170,  1183,  1184,  1197,  1217,  1224,  1228,
+         1272,  1276,  1309,  1323,  1347,  1367,  1377,  1404,  1473,  1475,  1486,  1509,  1537,  1544,  1546,  1553,
+         1555,  1576,  1589,  1594,  1600,  1602,  1616,  1625,  1636,  1638,  1665,  1667,  1672,  1685,  1706,  1722,
+         1737,  1755,  1816,  1831,  1850,  1856,  1862,  1874,  1901,  1932,  1950,  1971,  2011,  2032,  2052,  2063,
+         2077,  2079,  2091,  2095,  2172,  2192,  2207,  2208,  2224,  2230,  2247,  2277,  2308,  2345,  2356,  2389,
+         2403,  2424,  2501,  2504,  2506,  2520,  2570,  2593,  2616,  2624,  2630,  2646,  2669,  2700,  2714,  2746,
+         2754,  2795,  2824,  2835,  2839,  2874,  2882,  2905,  2984,  3028,  3042,  3092,  3108,  3110,  3124,  3153,
+         3185,  3215,  3252,  3288,  3294,  3364,  3397,  3434,  3483,  3523,  3537,  3587,  3589,  3591,  3592,  3610,
+         3626,  3670,  3680,  3722,  3749,  3754,  3776,  3789,  3803,  3824,  3857,  3873,  3904,  3906,  3924,  3992,
+    };
+    static const uint16_t kgrid_512[512] = {
+            0,     1,     2,     5,     7,     8,     9,    10,    12,    14,    16,    17,    21,    27,    32,    34,
+           37,    39,    41,    43,    48,    50,    57,    60,    63,    64,    65,    66,    68,    72,    73,    77,
+           80,    83,    87,    89,    93,   100,   113,   117,   122,   128,   129,   133,   135,   136,   139,   142,
+          145,   149,   152,   156,   162,   165,   167,   169,   171,   184,   187,   195,   201,   205,   208,   210,
+          217,   219,   222,   228,   232,   234,   247,   249,   253,   256,   267,   271,   273,   276,   282,   288,
+          291,   297,   312,   322,   324,   336,   338,   342,   347,   353,   357,   359,   374,   379,   390,   393,
+          395,   409,   426,   441,   448,   450,   452,   464,   466,   470,   475,   488,   492,   512,   513,   514,
+          516,   520,   521,   523,   525,   527,   528,   530,   537,   540,   542,   556,   558,   561,   570,   576,
+          577,   579,   582,   584,   588,   593,   600,   603,   609,   616,   618,   632,   638,   640,   650,   653,
+          655,   656,   660,   666,   672,   675,   685,   688,   698,   705,   708,   711,   712,   715,   721,   727,
+          728,   732,   737,   754,   760,   771,   773,   778,   780,   793,   795,   802,   806,   808,   812,   833,
+          840,   843,   849,   856,   858,   873,   912,   916,   919,   932,   934,   961,   963,   968,   970,   977,
+          989,   993,  1010,  1016,  1024,  1025,  1027,  1029,  1031,  1032,  1034,  1036,  1038,  1041,  1043,  1047,
+         1048,  1050,  1057,  1059,  1061,  1064,  1066,  1079,  1080,  1083,  1085,  1088,  1090,  1096,  1099,  1103,
+         1106,  1109,  1113,  1116,  1122,  1129,  1153,  1156,  1159,  1169,  1171,  1176,  1183,  1185,  1195,  1199,
+         1209,  1212,  1216,  1218,  1221,  1225,  1234,  1236,  1241,  1243,  1250,  1256,  1270,  1281,  1287,  1296,
+         1299,  1306,  1309,  1313,  1338,  1341,  1348,  1353,  1362,  1375,  1376,  1387,  1400,  1408,  1410,  1415,
+         1425,  1453,  1457,  1477,  1481,  1494,  1496,  1507,  1512,  1538,  1545,  1547,  1549,  1551,  1554,  1561,
+         1563,  1565,  1570,  1572,  1575,  1577,  1587,  1593,  1601,  1603,  1605,  1612,  1617,  1619,  1632,  1648,
+         1658,  1662,  1664,  1674,  1680,  1690,  1692,  1704,  1729,  1736,  1740,  1745,  1747,  1751,  1752,  1761,
+         1763,  1767,  1773,  1787,  1795,  1801,  1806,  1810,  1817,  1834,  1840,  1844,  1857,  1864,  1866,  1877,
+         1882,  1892,  1902,  1915,  1934,  1953,  1985,  1987,  2000,  2002,  2013,  2048,  2052,  2058,  2064,  2068,
+         2071,  2074,  2081,  2088,  2104,  2114,  2119,  2121,  2123,  2130,  2136,  2141,  2147,  2153,  2157,  2177,
+         2179,  2184,  2189,  2193,  2203,  2208,  2223,  2226,  2232,  2244,  2249,  2251,  2256,  2258,  2265,  2269,
+         2304,  2306,  2324,  2335,  2336,  2361,  2373,  2375,  2385,  2418,  2443,  2460,  2480,  2504,  2509,  2520,
+         2531,  2537,  2562,  2568,  2572,  2578,  2592,  2596,  2599,  2602,  2614,  2620,  2625,  2627,  2629,  2634,
+         2641,  2650,  2682,  2688,  2697,  2707,  2712,  2718,  2731,  2754,  2759,  2760,  2775,  2788,  2793,  2805,
+         2811,  2817,  2820,  2832,  2842,  2854,  2890,  2902,  2921,  2923,  2978,  3010,  3012,  3026,  3081,  3083,
+         3085,  3097,  3099,  3120,  3136,  3152,  3159,  3188,  3210,  3228,  3234,  3245,  3250,  3256,  3264,  3276,
+         3281,  3296,  3349,  3363,  3378,  3392,  3395,  3420,  3440,  3461,  3488,  3529,  3531,  3584,  3588,  3591,
+         3600,  3602,  3614,  3616,  3628,  3634,  3650,  3657,  3668,  3683,  3685,  3713,  3716,  3720,  3726,  3729,
+         3736,  3753,  3778,  3802,  3805,  3819,  3841,  3845,  3851,  3856,  3880,  3922,  3938,  3970,  3993,  4032,
+    };
+
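+    // 4 positions x 3 bits each -> 4096 possible packed indices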
+    const int kmap_size = 4096;
+    const int nwant = grid_size == 256 ? 2 : 3;
+    const uint16_t * kgrid = grid_size == 256 ? kgrid_256 : kgrid_512;
+    uint32_t * kgrid_q3xs;
+    int      * kmap_q3xs;
+    uint16_t * kneighbors_q3xs;
+
+    //printf("================================================================= %s(grid_size = %d)\n", __func__, grid_size);
+    uint32_t * the_grid = (uint32_t *)malloc(grid_size*sizeof(uint32_t));
+    for (int k = 0; k < grid_size; ++k) {
+        int8_t * pos = (int8_t *)(the_grid + k);
+        for (int i = 0; i < 4; ++i) {
+            int l = (kgrid[k] >> 3*i) & 0x7;
+            pos[i] = 2*l + 1;
+        }
+    }
+    kgrid_q3xs = the_grid;
+    iq3_data[gindex].grid = the_grid;
+    kmap_q3xs = (int *)malloc(kmap_size*sizeof(int));
+    iq3_data[gindex].map = kmap_q3xs;
+    for (int i = 0; i < kmap_size; ++i) kmap_q3xs[i] = -1;
+    uint32_t aux32;
+    uint8_t * aux8 = (uint8_t *)&aux32;
+    for (int i = 0; i < grid_size; ++i) {
+        aux32 = kgrid_q3xs[i];
+        uint16_t index = 0;
+        for (int k=0; k<4; ++k) {
+            uint16_t q = (aux8[k] - 1)/2;
+            index |= (q << 3*k);
+        }
+        kmap_q3xs[index] = i;
+    }
+    int8_t pos[4];
+    int * dist2 = (int *)malloc(2*grid_size*sizeof(int));
+    int num_neighbors = 0, num_not_in_map = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        ++num_not_in_map;
+        for (int k = 0; k < 4; ++k) {
+            int l = (i >> 3*k) & 0x7;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+        int n = 0; int d2 = dist2[0];
+        int nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            ++n;
+        }
+        num_neighbors += n;
+    }
+    //printf("%s: %d neighbours in total\n", __func__, num_neighbors);
+    kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t));
+    iq3_data[gindex].neighbours = kneighbors_q3xs;
+    int counter = 0;
+    for (int i = 0; i < kmap_size; ++i) {
+        if (kmap_q3xs[i] >= 0) continue;
+        for (int k = 0; k < 4; ++k) {
+            int l = (i >> 3*k) & 0x7;
+            pos[k] = 2*l + 1;
+        }
+        for (int j = 0; j < grid_size; ++j) {
+            const int8_t * pg = (const int8_t *)(kgrid_q3xs + j);
+            int d2 = 0;
+            for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]);
+            dist2[2*j+0] = d2;
+            dist2[2*j+1] = j;
+        }
+        qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func);
+        kmap_q3xs[i] = -(counter + 1);
+        int d2 = dist2[0];
+        uint16_t * start = &kneighbors_q3xs[counter++];
+        int n = 0, nhave = 1;
+        for (int j = 0; j < grid_size; ++j) {
+            if (dist2[2*j] > d2) {
+                if (nhave == nwant) break;
+                d2 = dist2[2*j];
+                ++nhave;
+            }
+            kneighbors_q3xs[counter++] = dist2[2*j+1];
+            ++n;
+        }
+        *start = n;
+    }
+    free(dist2);
+}
+
+void iq3xs_free_impl(int grid_size) {
+    GGML_ASSERT(grid_size == 256 || grid_size == 512);
+    const int gindex = iq3_data_index(grid_size);
+    if (iq3_data[gindex].grid) {
+        free(iq3_data[gindex].grid);       iq3_data[gindex].grid = NULL;
+        free(iq3_data[gindex].map);        iq3_data[gindex].map  = NULL;
+        free(iq3_data[gindex].neighbours); iq3_data[gindex].neighbours = NULL;
+    }
+}
+
+static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const uint32_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_d2 = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 4; ++i) {
+            float q = pg[i];
+            float diff = scale*q - xval[i];
+            d2 += weight[i]*diff*diff;
+        }
+        if (d2 < best_d2) {
+            best_d2 = d2; grid_index = neighbours[j];
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 4; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
+static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int64_t n,
+        const float * restrict quant_weights) {
+
+    const int gindex = iq3_data_index(grid_size);
+
+    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
+    const int      * kmap_q3xs       = iq3_data[gindex].map;
+    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 8;
+
+    const int64_t nbl = n/QK_K;
+
+    ggml_fp16_t * dh;
+    uint8_t * qs;
+    int block_size;
+    if (grid_size == 256) {
+        block_iq3_xxs * y = vy;
+        dh = &y->d;
+        qs = y->qs;
+        block_size = sizeof(block_iq3_xxs);
+    } else {
+        block_iq3_s * y = vy;
+        dh = &y->d;
+        qs = y->qs;
+        block_size = sizeof(block_iq3_s);
+    }
+    int quant_size = block_size - sizeof(ggml_fp16_t);
+
+    float scales[QK_K/32];
+    float weight[32];
+    float xval[32];
+    int8_t L[32];
+    int8_t Laux[32];
+    float  waux[32];
+    bool   is_on_grid[8];
+    bool   is_on_grid_aux[8];
+    uint8_t block_signs[8];
+    uint8_t q3[3*(QK_K/8)+QK_K/32];
+    uint32_t * scales_and_signs = (uint32_t *)(q3 + QK_K/4);
+    uint8_t  * qh = q3 + 3*(QK_K/8);
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        dh[0] = GGML_FP32_TO_FP16(0.f);
+        memset(q3, 0, 3*QK_K/8+QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            const float * xb = xbl + 32*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + 32*ib;
+                for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < 32; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int i = 0; i < 32; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < 4; ++k) {
+                int nflip = 0;
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; ++nflip; s |= (1 << i);
+                    }
+                }
+                if (nflip%2) {
+                    int imin = 0; float min = weight[8*k+imin]*xb[8*k+imin]*xb[8*k+imin];
+                    for (int i = 1; i < 8; ++i) {
+                        float ax = weight[8*k+i]*xb[8*k+i]*xb[8*k+i];
+                        if (ax < min) {
+                            min = ax; imin = i;
+                        }
+                    }
+                    xval[8*k+imin] = -xval[8*k+imin];
+                    s ^= (1 << imin);
+                }
+                block_signs[k] = s & 127;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]);
+            if (max < GROUP_MAX_EPS_IQ3_XXS) {
+                scales[ib] = 0;
+                memset(L, 0, 32);
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int is = -15; is <= 15; ++is) {
+                float id = (2*kMaxQ-1+is*0.2f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 8; ++k) {
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+                    int grid_index = kmap_q3xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  8; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 8; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 8; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 3*i);
+                    }
+                    int grid_index = kmap_q3xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 32; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < 4; ++k) block_signs[k] = (~block_signs[k]) & 127;
+            }
+            for (int k = 0; k < 8; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+                int grid_index = kmap_q3xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+                    printf("\n");
+                    GGML_ABORT("fatal error");
+                }
+                if (grid_size == 256) {
+                    q3[8*ib+k] = grid_index;
+                } else {
+                    q3[8*ib+k] = grid_index & 255;
+                    qh[ib] |= ((grid_index >> 8) << k);
+                }
+
+            }
+            scales_and_signs[ib] = block_signs[0] | (block_signs[1] << 7) | (block_signs[2] << 14) | (block_signs[3] << 21);
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            memset(qs, 0, quant_size);
+            dh += block_size/sizeof(ggml_fp16_t);
+            qs += block_size;
+            continue;
+        }
+
+        float d = max_scale/31;
+        dh[0] = GGML_FP32_TO_FP16(d * 1.0125f);  // small improvement via this fudge factor
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            scales_and_signs[ib] |= ((uint32_t)l << 28);
+        }
+        memcpy(qs, q3, quant_size);
+
+        dh += block_size/sizeof(ggml_fp16_t);
+        qs += block_size;
+
+    }
+}
+
+size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq3_xxs);
+    }
+    return nrow * nblock * sizeof(block_iq3_xxs);
+}
+
+void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
+}
+
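+// The scratch buffers are passed in by the caller so that quantize_iq3_s can allocate
+// them once and reuse them across rows.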
+static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, void * restrict vy, int n,
+        const float * restrict quant_weights,
+        float   * scales,
+        float   * weight,
+        float   * xval,
+        int8_t  * L,
+        int8_t  * Laux,
+        float   * waux,
+        bool    * is_on_grid,
+        bool    * is_on_grid_aux,
+        uint8_t * block_signs) {
+
+    const int gindex = iq3_data_index(512);
+
+    const uint32_t * kgrid_q3xs      = iq3_data[gindex].grid;
+    const int      * kmap_q3xs       = iq3_data[gindex].map;
+    const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q3xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q3xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 8;
+
+    const int64_t nbl = n/QK_K;
+
+    block_iq3_s * y = vy;
+
+    const int bs4 = block_size/4;
+    const int bs8 = block_size/8;
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        memset(&y[ibl], 0, sizeof(block_iq3_s));
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+        uint8_t * qs = y[ibl].qs;
+        uint8_t * qh = y[ibl].qh;
+        uint8_t * signs = y[ibl].signs;
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            const float * xb = xbl + block_size*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]);
+            for (int k = 0; k < bs8; ++k) {
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
+                    }
+                }
+                block_signs[k] = s;
+            }
+            float max = xval[0];
+            for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]);
+            if (!max) {
+                scales[ib] = 0;
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            for (int k = 0; k < bs4; ++k) is_on_grid[k] = false;
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.2f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < bs4; ++k) {
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i);
+                    int grid_index = kmap_q3xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < block_size; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < block_size; ++i) L[i] = Laux[i];
+                    for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < bs4; ++k) {
+                    //if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 4; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[4*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 3*i);
+                    }
+                    int grid_index = kmap_q3xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1;
+                        grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k);
+                    }
+                    const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index);
+                    for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2;
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < block_size; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale)
+                // and correspondingly flip quant signs.
+                scale = -scale;
+                for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k];
+            }
+            for (int k = 0; k < bs4; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i);
+                int grid_index = kmap_q3xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]);
+                    printf("\n");
+                    GGML_ABORT("fatal error");
+                }
+                qs[k] = grid_index & 255;
+                qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8));
+            }
+            qs += bs4;
+            for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k];
+            signs += bs8;
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d * 1.033f);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/block_size; ib += 2) {
+            int l1 = nearest_int(0.5f*(id*scales[ib+0]-1));
+            l1 = MAX(0, MIN(15, l1));
+            int l2 = nearest_int(0.5f*(id*scales[ib+1]-1));
+            l2 = MAX(0, MIN(15, l2));
+            y[ibl].scales[ib/2] = l1 | (l2 << 4);
+        }
+
+    }
+}
+
+#define IQ3S_BLOCK_SIZE 32
+size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int64_t nblock = n_per_row/QK_K;
+    float scales[QK_K/IQ3S_BLOCK_SIZE];
+    float weight[IQ3S_BLOCK_SIZE];
+    float xval[IQ3S_BLOCK_SIZE];
+    int8_t L[IQ3S_BLOCK_SIZE];
+    int8_t Laux[IQ3S_BLOCK_SIZE];
+    float  waux[IQ3S_BLOCK_SIZE];
+    bool   is_on_grid[IQ3S_BLOCK_SIZE/4];
+    bool   is_on_grid_aux[IQ3S_BLOCK_SIZE/4];
+    uint8_t block_signs[IQ3S_BLOCK_SIZE/8];
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq3_s_impl(IQ3S_BLOCK_SIZE, src, qrow, n_per_row, quant_weights,
+                scales, weight, xval, L, Laux, waux, is_on_grid, is_on_grid_aux, block_signs);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq3_s);
+    }
+    return nrow * nblock * sizeof(block_iq3_s);
+}
+
+void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    quantize_iq3_s(x, y, 1, k, NULL);
+}
+
+
+// =================================== 1.5 bpw ===================================================
+
+static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_score = -FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float sumqx = 0, sumq2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = (pg[i] - 3)/2;
+            float w = weight[i];
+            sumqx += w*q*xval[i];
+            sumq2 += w*q*q;
+        }
+        if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+            *scale = sumqx/sumq2; best_score = *scale * sumqx;
+            grid_index = neighbours[j];
+        }
+    }
+    if (grid_index < 0) {
+        for (int i = 0; i < ngrid; ++i) {
+            const int8_t * grid_i = (const int8_t *)(grid + i);
+            float sumqx = 0, sumq2 = 0;
+            for (int j = 0; j < 8; ++j) {
+                float w = weight[j];
+                float q = (grid_i[j] - 3)/2;
+                sumqx += w*q*xval[j];
+                sumq2 += w*q*q;
+            }
+            if (sumqx > 0 && sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                *scale = sumqx/sumq2; best_score = *scale*sumqx;
+                grid_index = i;
+            }
+        }
+    }
+    if (grid_index < 0) {
+        printf("Oops, did not find grid point\n");
+        printf("Have %d neighbours\n", num_neighbors);
+        for (int j = 1; j <= num_neighbors; ++j) {
+            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+            float sumqx = 0, sumq2 = 0;
+            for (int i = 0; i < 8; ++i) {
+                float q = (pg[i] - 3)/2;
+                float w = weight[i];
+                sumqx += w*q*xval[i];
+                sumq2 += w*q*q;
+            }
+            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+    *scale *= 1.05f;  // This is a fudge factor. Don't ask me why it improves the result.
+    //!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
+
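+// Like iq1_find_best_neighbour, but scores candidates by weighted squared distance at a
+// fixed scale instead of refitting the scale for every candidate.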
+static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
+        const float * restrict xval, const float * restrict weight, float scale, const float * restrict xg, int8_t * restrict L, int ngrid) {
+    int num_neighbors = neighbours[0];
+    GGML_ASSERT(num_neighbors > 0);
+    float best_score = FLT_MAX;
+    int grid_index = -1;
+    for (int j = 1; j <= num_neighbors; ++j) {
+        const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+        float d2 = 0;
+        for (int i = 0; i < 8; ++i) {
+            float q = xg[(pg[i] - 1)/2];
+            float w = weight[i];
+            float diff = scale*q - xval[i];
+            d2 += w*diff*diff;
+        }
+        if (d2 < best_score) {
+            best_score = d2;
+            grid_index = neighbours[j];
+        }
+    }
+    if (grid_index < 0) {
+        for (int i = 0; i < ngrid; ++i) {
+            const int8_t * grid_i = (const int8_t *)(grid + i);
+            float d2 = 0;
+            for (int j = 0; j < 8; ++j) {
+                float w = weight[j];
+                float q = xg[(grid_i[j] - 1)/2];
+                float diff = scale*q - xval[i];
+                d2 += w*diff*diff;
+            }
+            if (d2 < best_score) {
+                best_score = d2;
+                grid_index = i;
+            }
+        }
+    }
+    if (grid_index < 0) {
+        printf("Oops, did not find grid point\n");
+        printf("Have %d neighbours\n", num_neighbors);
+        for (int j = 1; j <= num_neighbors; ++j) {
+            const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
+            float sumqx = 0, sumq2 = 0;
+            for (int i = 0; i < 8; ++i) {
+                float q = xg[(pg[i] - 1)/2];
+                float w = weight[i];
+                sumqx += w*q*xval[i];
+                sumq2 += w*q*q;
+            }
+            printf("    neighbour %d: sumqx = %g sumq2 = %g\n", j, (double)sumqx, (double)sumq2);
+        }
+    }
+    GGML_ASSERT(grid_index >= 0);
+    const int8_t * pg = (const int8_t *)(grid + grid_index);
+    for (int i = 0; i < 8; ++i) L[i] = (pg[i] - 1)/2;
+    return grid_index;
+}
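+
+// Note on the two neighbour searches above: iq1_find_best_neighbour jointly picks the grid point
+// and the scale that maximize sumqx^2/sumq2 (the SSD reduction achievable with the optimal scale),
+// while iq1_find_best_neighbour2 keeps the caller-supplied scale fixed and simply minimizes the
+// weighted squared distance to each candidate grid point.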
+
+static int iq1_sort_helper(const void * left, const void * right) {
+    const float * l = left;
+    const float * r = right;
+    return *l < *r ? -1 : *l > *r ? 1 : 0;
+}
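+
+// The quantizers below sort values together with their original positions: pairs[2*j] holds the
+// value, and the float slot right after it is reused (via idx = (int *)(pairs + 1)) to store the
+// index j. qsort-ing elements of size 2*sizeof(float) with the helper above therefore reorders
+// the values while carrying their indices along.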
+
+#define IQ1S_BLOCK_SIZE 32
+#define IQ1M_BLOCK_SIZE 16
+static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
+        float    * scales,
+        float    * weight,
+        float    * sumx,
+        float    * sumw,
+        float    * pairs,
+        int8_t   * L,
+        uint16_t * index,
+        int8_t   * shifts) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    block_iq1_s * y = vy;
+
+    const int64_t nbl = n/QK_K;
+
+    const int block_size = IQ1S_BLOCK_SIZE;
+
+    const float x_p[3] = {-1 + IQ1S_DELTA,  IQ1S_DELTA, 1 + IQ1S_DELTA};
+    const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};
+
+
+    int * idx = (int *)(pairs + 1);
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+        memset(y[ibl].qs, 0, QK_K/8);
+        memset(y[ibl].qh, 0, QK_K/16);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            const float * xb = xbl + block_size*ib;
+            const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+            for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            float max = fabsf(xb[0]);
+            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
+            if (max < GROUP_MAX_EPS_IQ1_S) {
+                scales[ib] = 0;
+                memset(L, 1, block_size);
+                continue;
+            }
+            // Here we solve the weighted sum of squared differences (SSD) minimization problem exactly.
+            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+            // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
+            // and score for each possible split.
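+            // (For a fixed split, minimizing sum_i w_i*(x_i - scale*q_i)^2 over the scale gives
+            //  scale = sumqx/sumq2 and an error reduction of sumqx^2/sumq2, which is why candidate
+            //  splits are compared via sumqx*sumqx > best_score*sumq2 below.)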
+            for (int j = 0; j < block_size; ++j) {
+                pairs[2*j] = xb[j];
+                idx[2*j] = j;
+            }
+            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
+            {
+                sumx[0] = sumw[0] = 0;
+                for (int j = 0; j < block_size; ++j) {
+                    int i = idx[2*j];
+                    sumx[j+1] = sumx[j] + weight[i]*xb[i];
+                    sumw[j+1] = sumw[j] + weight[i];
+                }
+            }
+            float best_score = -FLT_MIN, scale = max;
+            int besti1 = -1, besti2 = -1, best_shift = 0;
+            for (int i1 = 0; i1 <= block_size; ++i1) {
+                for (int i2 = i1; i2 <= block_size; ++i2) {
+                    float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[block_size] - sumx[i2])*x_p[2];
+                    float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[block_size] - sumw[i2])*x_p[2]*x_p[2];
+                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                        scale = sumqx/sumq2; best_score = scale*sumqx;
+                        besti1 = i1; besti2 = i2; best_shift = 1;
+                    }
+                    sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[block_size] - sumx[i2])*x_m[2];
+                    sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[block_size] - sumw[i2])*x_m[2]*x_m[2];
+                    if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
+                        scale = sumqx/sumq2; best_score = scale*sumqx;
+                        besti1 = i1; besti2 = i2; best_shift = -1;
+                    }
+                }
+            }
+            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
+            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
+            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
+            if (scale < 0) {
+                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
+                scale = -scale; best_shift = -best_shift;
+            }
+            bool all_on_grid = true;
+            const float * xx = best_shift == 1 ? x_p : x_m;
+            for (int k = 0; k < block_size/8; ++k) {
+                uint16_t u = 0;
+                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    all_on_grid = false;
+                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
+                    GGML_ASSERT(grid_index >= 0);
+                }
+                index[k] = grid_index;
+            }
+            if (!all_on_grid) {
+                float sumqx = 0, sumq2 = 0;
+                for (int k = 0; k < block_size/8; ++k) {
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+                    for (int j = 0; j < 8; ++j) {
+                        float w = weight[8*k + j];
+                        float q = xx[(pg[j] - 1)/2];
+                        sumqx += w*q*xb[8*k+j];
+                        sumq2 += w*q*q;
+                    }
+                }
+                if (sumqx > 0 && sumq2 > 0) scale = sumqx/sumq2;
+            }
+            uint16_t h = 0;
+            for (int k = 0; k < block_size/8; ++k) {
+                y[ibl].qs[(block_size/8)*ib + k] = index[k] & 255;
+                h |= (index[k] >> 8) << 3*k;
+            }
+            y[ibl].qh[ib] = h;
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            shifts[ib] = best_shift;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            continue;
+        }
+
+        float d = max_scale/15;
+        y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.125f is another fudge factor. Don't ask me why it is needed.
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(7, l));
+            if (shifts[ib] == -1) l |= 8;
+            y[ibl].qh[ib] |= (l << 12);
+        }
+    }
+}
+
+size_t quantize_iq1_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    float  scales[QK_K/IQ1S_BLOCK_SIZE];
+    float  weight[IQ1S_BLOCK_SIZE];
+    int8_t L[IQ1S_BLOCK_SIZE];
+    float  sumx[IQ1S_BLOCK_SIZE+1];
+    float  sumw[IQ1S_BLOCK_SIZE+1];
+    float  pairs[2*IQ1S_BLOCK_SIZE];
+    uint16_t index[IQ1S_BLOCK_SIZE/8];
+    int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights, scales, weight, sumx, sumw, pairs, L, index, shifts);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq1_s);
+    }
+    return nrow * nblock * sizeof(block_iq1_s);
+}
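+
+// A minimal usage sketch, assuming a caller-side buffer and an importance matrix of n_per_row
+// entries; as the asserts above indicate, ggml_quantize_init(GGML_TYPE_IQ1_S) must have been
+// called and quant_weights must not be NULL:
+//
+//   const int64_t nblock = n_per_row/QK_K;               // n_per_row must be a multiple of QK_K
+//   const size_t  row_sz = nblock*sizeof(block_iq1_s);
+//   void * dst = malloc(nrow*row_sz);                    // hypothetical destination buffer
+//   size_t written = quantize_iq1_s(src, dst, nrow, n_per_row, imatrix);
+//   // written == nrow*row_sz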
+
+static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights,
+        float    * scales,
+        float    * weight,
+        float    * pairs,
+        int8_t   * L,
+        uint16_t * index,
+        int8_t   * shifts) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ1_M);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    //GGML_ASSERT(quant_weights   && "missing quantization weights");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    block_iq1_m * y = vy;
+
+    const int64_t nbl = n/QK_K;
+
+    const int block_size = IQ1M_BLOCK_SIZE;
+
+    const float x_p[3] = {-1 + IQ1M_DELTA,  IQ1M_DELTA, 1 + IQ1M_DELTA};
+    const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
+    const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
+
+    int * idx = (int *)(pairs + 1);
+
+    float sumqx[4], sumq2[4];
+
+    iq1m_scale_t s;
+    const float * xx;
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+        memset(y[ibl].qs, 0, QK_K/8);
+        memset(y[ibl].qh, 0, QK_K/16);
+        memset(y[ibl].scales, 0, QK_K/32);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            const float * xb = xbl + block_size*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+            }
+            float max = fabsf(xb[0]);
+            for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
+            if (max < GROUP_MAX_EPS_IQ1_M) {
+                scales[ib] = 0;
+                memset(L, 1, block_size);
+                continue;
+            }
+            // Here we solve the weighted sum of squared differences (SSD) minimization problem exactly.
+            // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
+            // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
+            // in ascending order, compute Si = sum[weight[j] xb[j], j = 0...i] and
+            // Wi = sum[weight[j], j = 0...i], and use these to quickly get the optimum scale
+            // and score for each possible split.
+            for (int j = 0; j < block_size; ++j) {
+                pairs[2*j] = xb[j];
+                idx[2*j] = j;
+            }
+            qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
+            float best_score = -FLT_MIN, scale = max;
+            int besti1 = -1, besti2 = -1, best_k = -1;
+            // 0: +, +
+            // 1: +, -
+            // 2: -, +
+            // 3: -, -
+            for (int i1 = 0; i1 <= block_size; ++i1) {
+                for (int i2 = i1; i2 <= block_size; ++i2) {
+                    memset(sumqx, 0, 4*sizeof(float));
+                    memset(sumq2, 0, 4*sizeof(float));
+                    for (int j = 0; j < i1; ++j) {
+                        int i = idx[2*j];
+                        if (i < block_size/2) {
+                            sumqx[0] += weight[i]*x_p[0]*xb[i];
+                            sumqx[1] += weight[i]*x_p[0]*xb[i];
+                            sumqx[2] += weight[i]*x_m[0]*xb[i];
+                            sumqx[3] += weight[i]*x_m[0]*xb[i];
+                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[1] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[2] += weight[i]*x_m[0]*x_m[0];
+                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
+                        } else {
+                            sumqx[0] += weight[i]*x_p[0]*xb[i];
+                            sumqx[2] += weight[i]*x_p[0]*xb[i];
+                            sumqx[1] += weight[i]*x_m[0]*xb[i];
+                            sumqx[3] += weight[i]*x_m[0]*xb[i];
+                            sumq2[0] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[2] += weight[i]*x_p[0]*x_p[0];
+                            sumq2[1] += weight[i]*x_m[0]*x_m[0];
+                            sumq2[3] += weight[i]*x_m[0]*x_m[0];
+                        }
+                    }
+                    for (int j = i1; j < i2; ++j) {
+                        int i = idx[2*j];
+                        if (i < block_size/2) {
+                            sumqx[0] += weight[i]*x_p[1]*xb[i];
+                            sumqx[1] += weight[i]*x_p[1]*xb[i];
+                            sumqx[2] += weight[i]*x_m[1]*xb[i];
+                            sumqx[3] += weight[i]*x_m[1]*xb[i];
+                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[1] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[2] += weight[i]*x_m[1]*x_m[1];
+                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
+                        } else {
+                            sumqx[0] += weight[i]*x_p[1]*xb[i];
+                            sumqx[2] += weight[i]*x_p[1]*xb[i];
+                            sumqx[1] += weight[i]*x_m[1]*xb[i];
+                            sumqx[3] += weight[i]*x_m[1]*xb[i];
+                            sumq2[0] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[2] += weight[i]*x_p[1]*x_p[1];
+                            sumq2[1] += weight[i]*x_m[1]*x_m[1];
+                            sumq2[3] += weight[i]*x_m[1]*x_m[1];
+                        }
+                    }
+                    for (int j = i2; j < block_size; ++j) {
+                        int i = idx[2*j];
+                        if (i < block_size/2) {
+                            sumqx[0] += weight[i]*x_p[2]*xb[i];
+                            sumqx[1] += weight[i]*x_p[2]*xb[i];
+                            sumqx[2] += weight[i]*x_m[2]*xb[i];
+                            sumqx[3] += weight[i]*x_m[2]*xb[i];
+                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[1] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[2] += weight[i]*x_m[2]*x_m[2];
+                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
+                        } else {
+                            sumqx[0] += weight[i]*x_p[2]*xb[i];
+                            sumqx[2] += weight[i]*x_p[2]*xb[i];
+                            sumqx[1] += weight[i]*x_m[2]*xb[i];
+                            sumqx[3] += weight[i]*x_m[2]*xb[i];
+                            sumq2[0] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[2] += weight[i]*x_p[2]*x_p[2];
+                            sumq2[1] += weight[i]*x_m[2]*x_m[2];
+                            sumq2[3] += weight[i]*x_m[2]*x_m[2];
+                        }
+                    }
+                    for (int k = 0; k < 4; ++k) {
+                        if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
+                            scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
+                            besti1 = i1; besti2 = i2; best_k = k;
+                        }
+                    }
+                }
+            }
+            GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0);
+            for (int j =      0; j < besti1; ++j) L[idx[2*j]] = 0;
+            for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
+            for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2;
+            if (scale < 0) {
+                for (int j = 0; j < block_size; ++j) L[j] = 2 - L[j];
+                scale = -scale;
+                best_k = best_k == 0 ? 3 : best_k == 1 ? 2 : best_k == 2 ? 1 : 0;
+            }
+            bool all_on_grid = true;
+            for (int k = 0; k < block_size/8; ++k) {
+                if (k == 0) xx = best_k < 2 ? x_p : x_m;
+                else xx = best_k%2 == 0 ? x_p : x_m;
+                uint16_t u = 0;
+                for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    all_on_grid = false;
+                    const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                    grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
+                    GGML_ASSERT(grid_index >= 0);
+                }
+                index[k] = grid_index;
+            }
+            if (!all_on_grid) {
+                float sumqx_f = 0, sumq2_f = 0;
+                for (int k = 0; k < block_size/8; ++k) {
+                    if (k == 0) xx = best_k < 2 ? x_p : x_m;
+                    else xx = best_k%2 == 0 ? x_p : x_m;
+                    const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
+                    for (int j = 0; j < 8; ++j) {
+                        float w = weight[8*k + j];
+                        float q = xx[(pg[j] - 1)/2];
+                        sumqx_f += w*q*xb[8*k+j];
+                        sumq2_f += w*q*q;
+                    }
+                }
+                if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
+            }
+            y[ibl].qs[2*ib + 0] = index[0] & 255;
+            y[ibl].qs[2*ib + 1] = index[1] & 255;
+            y[ibl].qh[ib] = (index[0] >> 8) | ((index[1] >> 8) << 4);
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            shifts[ib] = best_k;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            continue;
+        }
+
+        uint16_t * sc = (uint16_t *)y[ibl].scales;
+        float d = max_scale/15;
+        float id = 1/d;
+        float sumqx_f = 0, sumq2_f = 0;
+        for (int ib = 0; ib < QK_K/block_size; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib+0]-1));
+            l = MAX(0, MIN(7, l));
+            sc[ib/4] |= (l << 3*(ib%4));
+            y[ibl].qh[ib] |= masks[shifts[ib]];
+            const float * xb = xbl + block_size*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + block_size*ib;
+                for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
+            }
+            for (int k = 0; k < block_size/8; ++k) {
+                if (k == 0) xx = shifts[ib] < 2 ? x_p : x_m;
+                else xx = shifts[ib]%2 == 0 ? x_p : x_m;
+                const int8_t * pg = (const int8_t *)(kgrid_q2xs + y[ibl].qs[2*ib+k] + ((y[ibl].qh[ib] << (8 - 4*k)) & 0x700));
+                for (int j = 0; j < 8; ++j) {
+                    float w = weight[8*k + j];
+                    float q = xx[(pg[j] - 1)/2]*(2*l+1);
+                    sumqx_f += w*q*xb[8*k+j];
+                    sumq2_f += w*q*q;
+                }
+            }
+        }
+        if (sumq2_f > 0) d = sumqx_f/sumq2_f;
+        s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
+        sc[0] |= ((s.u16 & 0x000f) << 12);
+        sc[1] |= ((s.u16 & 0x00f0) <<  8);
+        sc[2] |= ((s.u16 & 0x0f00) <<  4);
+        sc[3] |= ((s.u16 & 0xf000) <<  0);
+    }
+}
+
+size_t quantize_iq1_m(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    float  scales[QK_K/IQ1M_BLOCK_SIZE];
+    float  weight[IQ1M_BLOCK_SIZE];
+    int8_t L[IQ1M_BLOCK_SIZE];
+    float  pairs[2*IQ1M_BLOCK_SIZE];
+    uint16_t index[IQ1M_BLOCK_SIZE/8];
+    int8_t shifts[QK_K/IQ1M_BLOCK_SIZE];
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq1_m_impl(src, qrow, n_per_row, quant_weights, scales, weight, pairs, L, index, shifts);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq1_m);
+    }
+    return nrow * nblock * sizeof(block_iq1_m);
+}
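+
+// Unlike quantize_iq1_s above, quantize_iq1_m tolerates quant_weights == NULL: the impl falls back
+// to weighting each value by xb[i]*xb[i] when no importance matrix is provided (the corresponding
+// assert is commented out above).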
+
+// ============================ 4-bit non-linear quants
+
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
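+
+// best_index_int8 assumes val[] is sorted in ascending order: it binary-searches for the bracketing
+// pair of entries and returns the index of whichever one is closer to x.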
+
+static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * restrict x,
+        ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
+        float * scales, float * weight, uint8_t * L,
+        const int8_t * values,
+        const float * quant_weights,
+        const int ntry) {
+
+    float sigma2 = 0;
+    for (int j = 0; j < super_block_size; ++j) sigma2 += x[j]*x[j];
+    sigma2 *= 2.f/super_block_size;
+
+    memset(q4, 0, super_block_size/2);
+    dh[0] = GGML_FP32_TO_FP16(0.f);
+
+    float max_scale = 0, amax_scale = 0;
+    for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+        const float * xb = x + ib*block_size;
+        uint8_t * Lb = L + ib*block_size;
+        if (quant_weights) {
+            const float * qw = quant_weights + ib*block_size;
+            for (int j = 0; j < block_size; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
+        } else {
+            for (int j = 0; j < block_size; ++j) weight[j] = xb[j]*xb[j];
+        }
+        float amax = 0, max = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float ax = fabsf(xb[j]);
+            if (ax > amax) {
+                amax = ax; max = xb[j];
+            }
+        }
+        if (amax < GROUP_MAX_EPS) {
+            scales[ib] = 0;
+            continue;
+        }
+        float d = ntry > 0 ? -max/values[0] : max/values[0];
+        float id = 1/d;
+        float sumqx = 0, sumq2 = 0;
+        for (int j = 0; j < block_size; ++j) {
+            float al = id*xb[j];
+            int l = best_index_int8(16, values, al);
+            Lb[j] = l;
+            float q = values[l];
+            float w = weight[j];
+            sumqx += w*q*xb[j];
+            sumq2 += w*q*q;
+        }
+        d = sumqx/sumq2;
+        float best = d*sumqx;
+        for (int itry = -ntry; itry <= ntry; ++itry) {
+            id = (itry + values[0])/max;
+            sumqx = sumq2 = 0;
+            for (int j = 0; j < block_size; ++j) {
+                float al = id*xb[j];
+                int l = best_index_int8(16, values, al);
+                float q = values[l];
+                float w = weight[j];
+                sumqx += w*q*xb[j];
+                sumq2 += w*q*q;
+            }
+            if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                d = sumqx/sumq2; best = d * sumqx;
+            }
+        }
+        scales[ib] = d;
+        float abs_d = fabsf(d);
+        if (abs_d > amax_scale) {
+            amax_scale = abs_d; max_scale = d;
+        }
+    }
+
+    if (super_block_size/block_size > 1) {
+        int nb = super_block_size/block_size;
+        memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
+        float d = -max_scale/32;
+        dh[0] = GGML_FP32_TO_FP16(d);
+        float id = d ? 1/d : 0.f;
+        for (int ib = 0; ib < super_block_size/block_size; ++ib) {
+            int l = nearest_int(id*scales[ib]);
+            l = MAX(-32, MIN(31, l));
+            float dl = d * l;
+            float idl = dl ? 1/dl : 0.f;
+            uint8_t * Lb = L + ib*block_size;
+            const float * xb = x + ib*block_size;
+            for (int j = 0; j < block_size; ++j) {
+                Lb[j] = best_index_int8(16, values, idl*xb[j]);
+            }
+            l += 32;
+            uint8_t l_l = l & 0xf;
+            uint8_t l_h = l >>  4;
+            if (ib%2 == 0) scales_l[ib/2] = l_l;
+            else scales_l[ib/2] |= (l_l << 4);
+            scales_h[ib/8] |= (l_h << 2*(ib%8));
+        }
+    } else {
+        dh[0] = GGML_FP32_TO_FP16(scales[0]);
+        if (ntry > 0) {
+            float id = scales[0] ? 1/scales[0] : 0;
+            for (int j = 0; j < super_block_size; ++j) {
+                L[j] = best_index_int8(16, values, id*x[j]);
+            }
+        }
+    }
+
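+    // pack two 4-bit indices per byte: within each group of 32 values the first 16 go into the
+    // low nibbles and the second 16 into the high nibbles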
+    for (int i = 0; i < super_block_size/32; ++i) {
+        for (int j = 0; j < 16; ++j) {
+            q4[16*i + j] = L[32*i + j] | (L[32*i + 16 + j] << 4);
+        }
+    }
+}
+
+size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK4_NL == 0);
+    int64_t nblock = n_per_row/QK4_NL;
+    char * qrow = (char *)dst;
+    uint8_t L[QK4_NL];
+    float weight[QK4_NL];
+    uint16_t unused_h;
+    uint8_t * unused_l = NULL;
+    float scale;
+    for (int64_t row = 0; row < nrow; ++row) {
+        block_iq4_nl * iq4 = (block_iq4_nl *)qrow;
+        for (int ibl = 0; ibl < nblock; ++ibl) {
+            const float * qw = quant_weights ? quant_weights + QK4_NL*ibl : NULL;
+            quantize_row_iq4_nl_impl(QK4_NL, 32, src + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+                    &scale, weight, L, kvalues_iq4nl, qw, 7);
+        }
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq4_nl);
+    }
+    return nrow * nblock * sizeof(block_iq4_nl);
+}
+
+//void quantize_row_iq4_nl_ref(const float * restrict x, void * restrict vy, int64_t k) {
+void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
+    GGML_ASSERT(k%QK4_NL == 0);
+    int64_t nblock = k/QK4_NL;
+    uint8_t L[QK4_NL];
+    float weight[QK4_NL];
+    uint16_t unused_h;
+    uint8_t * unused_l = NULL;
+    float scale;
+    block_iq4_nl * iq4 = y;
+    for (int ibl = 0; ibl < nblock; ++ibl) {
+        quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l,
+                &scale, weight, L, kvalues_iq4nl, NULL, -1);
+    }
+}
+
+size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    uint8_t L[QK_K];
+    float weight[32];
+    float scales[QK_K/32];
+    for (int64_t row = 0; row < nrow; ++row) {
+        block_iq4_xs * iq4 = (block_iq4_xs *)qrow;
+        for (int ibl = 0; ibl < nblock; ++ibl) {
+            const float * qw = quant_weights ? quant_weights + QK_K*ibl : NULL;
+            quantize_row_iq4_nl_impl(QK_K, 32, src + QK_K*ibl, &iq4[ibl].d, iq4[ibl].qs, &iq4[ibl].scales_h, iq4[ibl].scales_l,
+                    scales, weight, L, kvalues_iq4nl, qw, 7);
+        }
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq4_xs);
+    }
+    return nrow * nblock * sizeof(block_iq4_xs);
+}
+
+void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    quantize_iq4_xs(x, y, 1, k, NULL);
+}
+
+// =============================== 2.5625 bpw
+
+static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy, int64_t n, const float * restrict quant_weights) {
+
+    const int gindex = iq2_data_index(GGML_TYPE_IQ2_S);
+
+    const uint64_t * kgrid_q2xs      = iq2_data[gindex].grid;
+    const int      * kmap_q2xs       = iq2_data[gindex].map;
+    const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
+
+    GGML_ASSERT(kmap_q2xs       && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kgrid_q2xs      && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
+    GGML_ASSERT(n%QK_K == 0);
+
+    const int kMaxQ = 3;
+
+    const int64_t nbl = n/QK_K;
+
+    block_iq2_s * y = vy;
+
+    float scales[QK_K/16];
+    float weight[16];
+    float xval[16];
+    int8_t L[16];
+    int8_t Laux[16];
+    float  waux[16];
+    bool   is_on_grid[2];
+    bool   is_on_grid_aux[2];
+    uint8_t block_signs[2];
+
+    for (int ibl = 0; ibl < nbl; ++ibl) {
+
+        memset(&y[ibl], 0, sizeof(block_iq2_s));
+        y[ibl].d = GGML_FP32_TO_FP16(0.f);
+
+        float max_scale = 0;
+
+        const float * xbl = x + QK_K*ibl;
+        float sumx2 = 0;
+        for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
+        float sigma2 = 2*sumx2/QK_K;
+
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            const float * xb = xbl + 16*ib;
+            if (quant_weights) {
+                const float * qw = quant_weights + QK_K*ibl + 16*ib;
+                for (int i = 0; i < 16; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
+            } else {
+                for (int i = 0; i < 16; ++i) weight[i] = 0.25f*sigma2 + xb[i]*xb[i];
+            }
+            for (int i = 0; i < 16; ++i) waux[i] = sqrtf(weight[i]);
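+            // fold out the signs: the search below works on magnitudes only, while the original
+            // signs of each group of 8 values are kept as a bitmask in block_signs[] and stored
+            // with the block at the end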
+            for (int k = 0; k < 2; ++k) {
+                uint8_t s = 0;
+                for (int i = 0; i < 8; ++i) {
+                    if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i];
+                    else {
+                        xval[8*k + i] = -xb[8*k + i]; s |= (1 << i);
+                    }
+                }
+                block_signs[k] = s;
+            }
+            float max = xval[0];
+            for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]);
+            if (max < GROUP_MAX_EPS_IQ2_S) {
+                scales[ib] = 0;
+                continue;
+            }
+            float best = 0;
+            float scale = max/(2*kMaxQ-1);
+            is_on_grid[0] = is_on_grid[1] = true;
+            for (int is = -9; is <= 9; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/max;
+                float this_scale = 1/id;
+                for (int k = 0; k < 2; ++k) {
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        Laux[8*k+i] = MAX(0, MIN(kMaxQ-1, l));
+                    }
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
+                    int grid_index = kmap_q2xs[u];
+                    is_on_grid_aux[k] = true;
+                    if (grid_index < 0) {
+                        is_on_grid_aux[k] = false;
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*Laux[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
+                    scale = sumqx/sumq2; best = scale*sumqx;
+                    for (int i = 0; i < 16; ++i) L[i] = Laux[i];
+                    for (int k = 0; k <  2; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                }
+            }
+            int n_not_ongrid = 0;
+            for (int k = 0; k < 2; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
+            if (n_not_ongrid > 0 && scale > 0) {
+                float id = 1/scale;
+                for (int k = 0; k < 2; ++k) {
+                    if (is_on_grid[k]) continue;
+                    uint16_t u = 0;
+                    for (int i = 0; i < 8; ++i) {
+                        int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
+                        l = MAX(0, MIN(kMaxQ-1, l));
+                        u |= (l << 2*i);
+                        L[8*k + i] = l;
+                    }
+                    int grid_index = kmap_q2xs[u];
+                    if (grid_index < 0) {
+                        const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
+                        grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, scale, L + 8*k);
+                    }
+                }
+                float sumqx = 0, sumq2 = 0;
+                for (int i = 0; i < 16; ++i) {
+                    float w = weight[i];
+                    float q = 2*L[i] + 1;
+                    sumqx += w*xval[i]*q;
+                    sumq2 += w*q*q;
+                }
+                if (sumq2 > 0) scale = sumqx/sumq2;
+            }
+            if (scale < 0) {
+                scale = -scale;
+                for (int k = 0; k < 2; ++k) block_signs[k] = ~block_signs[k];
+            }
+            for (int k = 0; k < 2; ++k) {
+                uint16_t u = 0;
+                for (int i = 0; i < 8; ++i) u |= (L[8*k+i] << 2*i);
+                int grid_index = kmap_q2xs[u];
+                if (grid_index < 0) {
+                    printf("Oops: found point %u not on grid:", u);
+                    for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]);
+                    printf("\n");
+                    GGML_ABORT("fatal error");
+                }
+                const int i8 = 2*ib + k;
+                y[ibl].qs[i8] = grid_index & 255;
+                y[ibl].qh[i8/4] |= ((grid_index >> 8) << 2*(i8%4));
+                y[ibl].qs[QK_K/8 + i8] = block_signs[k];
+            }
+            GGML_ASSERT(scale >= 0);
+            scales[ib] = scale;
+            max_scale = MAX(max_scale, scale);
+        }
+
+        if (!max_scale) {
+            continue;
+        }
+
+        float d = max_scale/31;
+        y[ibl].d = GGML_FP32_TO_FP16(d * 0.9875f);
+        float id = 1/d;
+        for (int ib = 0; ib < QK_K/16; ++ib) {
+            int l = nearest_int(0.5f*(id*scales[ib]-1));
+            l = MAX(0, MIN(15, l));
+            if (ib%2 == 0) y[ibl].scales[ib/2] = l;
+            else y[ibl].scales[ib/2] |= (l << 4);
+        }
+    }
+}
+
+size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_ASSERT(n_per_row%QK_K == 0);
+    int64_t nblock = n_per_row/QK_K;
+    char * qrow = (char *)dst;
+    for (int64_t row = 0; row < nrow; ++row) {
+        quantize_row_iq2_s_impl(src, qrow, n_per_row, quant_weights);
+        src += n_per_row;
+        qrow += nblock*sizeof(block_iq2_s);
+    }
+    return nrow * nblock * sizeof(block_iq2_s);
+}
+
+void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
+    assert(k % QK_K == 0);
+    quantize_iq2_s(x, y, 1, k, NULL);
+}
+
+// =============================== data validation
+
+static bool validate_float(float f, size_t i) {
+    if (isinf(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
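+// an IEEE fp16 value has all exponent bits (0x7c00) set for both inf and NaN; the mantissa bits
+// (0x03ff) distinguish inf (zero) from NaN (non-zero)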
+static bool isinf_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
+}
+
+static bool isnan_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
+}
+
+static bool validate_fp16(ggml_fp16_t f, size_t i) {
+    if (isinf_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
+}
+
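+// the following helpers validate the fp16 scale members of a row of quantized blocks; the d/m
+// macro arguments name the struct members to check (e.g. d and dmin for the K-quants)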
+#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i)) { \
+            return false; \
+        } \
+    }
+
+#define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
+            return false; \
+        } \
+    }
+
+#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        for (size_t j = 0; j < (nr); ++j) { \
+            if (!validate_fp16(q[i].d[j], i)) { \
+                return false; \
+            } \
+        } \
+    }
+
+bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
+    if (type < 0 || type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+        return false;
+    }
+
+    if (nbytes % ggml_type_size(type) != 0) {
+        fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type));
+        return false;
+    }
+
+    const size_t nb = nbytes/ggml_type_size(type);
+
+    switch (type) {
+        case GGML_TYPE_BF16:
+            {
+                int nans = 0;
+                int infs = 0;
+                const unsigned short * f = (const unsigned short *) data;
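+                // BF16: masking with 0x7fff drops the sign bit; the 8 exponent bits are 0x7f80, so
+                // a masked value equal to 0x7f80 is +/-inf and anything larger is a NaN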
+                for (size_t i = 0; i < nb; ++i) {
+                    nans += (f[i] & 0x7fff) > 0x7f80;
+                    infs += (f[i] & 0x7fff) == 0x7f80;
+                }
+                if (nans) {
+                    fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
+                    return false;
+                }
+                if (infs) {
+                    fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
+                    return false;
+                }
+            } break;
+        case GGML_TYPE_F16:
+            {
+                const ggml_fp16_t * f = (const ggml_fp16_t *) data;
+                size_t i = 0;
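+                // vectorized fast path: test the exponent bits of 16 (AVX2) or 8 (NEON) values at a
+                // time; if any lane has them all set, re-check that chunk with the scalar validator,
+                // which is guaranteed to report the bad value and return false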
+#if defined(__AVX2__)
+                for (; i + 15 < nb; i += 16) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
+                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 16; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 7 < nb; i += 8) {
+                    uint16x8_t v = vld1q_u16(f + i);
+                    uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
+                    uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_fp16(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F32:
+            {
+                const float * f = (const float *) data;
+                size_t i = 0;
+#if defined(__AVX2__)
+                for (; i + 7 < nb; i += 8) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
+                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#elif defined(__ARM_NEON)
+                for (; i + 3 < nb; i += 4) {
+                    uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
+                    uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
+                    uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
+                    uint64_t mask = vget_lane_u64(vreinterpret_u64_u16(vshrn_n_u32(cmp, 8)), 0);
+                    if (mask) {
+                        for (size_t j = 0; j < 4; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_F64:
+            {
+                const double * f = (const double *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(f[i], i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_Q4_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
+            } break;
+        case GGML_TYPE_Q4_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q5_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_0, data, nb);
+            } break;
+        case GGML_TYPE_Q5_1:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_1, data, nb, d, m);
+            } break;
+        case GGML_TYPE_Q8_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
+            } break;
+        case GGML_TYPE_Q2_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q3_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q3_K, data, nb);
+            } break;
+        case GGML_TYPE_Q4_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q5_K:
+            {
+                VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin);
+            } break;
+        case GGML_TYPE_Q6_K:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_q6_K, data, nb);
+            } break;
+        case GGML_TYPE_Q8_K:
+            {
+                const block_q8_K * q = (const block_q8_K *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    if (!validate_float(q[i].d, i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_TQ1_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb);
+            } break;
+        case GGML_TYPE_TQ2_0:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb);
+            } break;
+        case GGML_TYPE_IQ1_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ1_M:
+            {
+                const block_iq1_m * q = (const block_iq1_m *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    iq1m_scale_t scale;
+                    const uint16_t * sc = (const uint16_t *)q[i].scales;
+                    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+                    if (!validate_fp16(scale.f16, i)) {
+                        return false;
+                    }
+                }
+            } break;
+        case GGML_TYPE_IQ2_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xxs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_XS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_xs, data, nb);
+            } break;
+        case GGML_TYPE_IQ2_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq2_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ3_XXS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_xxs, data, nb);
+            } break;
+
+        case GGML_TYPE_IQ3_S:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb);
+            } break;
+        case GGML_TYPE_IQ4_XS:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb);
+            } break;
+        case GGML_TYPE_IQ4_NL:
+            {
+                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
+            } break;
+
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_I64:
+            // nothing to validate
+            break;
+        default:
+            {
+                fprintf(stderr, "%s: invalid type %d\n", __func__, type);
+                return false;
+            }
+    }
+
+    return true;
+}
diff --git a/llama/ggml-quants.h b/llama/ggml-quants.h
new file mode 100644
index 000000000..cf518ba08
--- /dev/null
+++ b/llama/ggml-quants.h
@@ -0,0 +1,126 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// NOTE: these functions are defined as GGML_API because they are used by the CPU backend
+
+// Quantization
+GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+
+GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
+
+GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k);
+
+GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl  * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs  * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_iq3_s_ref  (const float * GGML_RESTRICT x, block_iq3_s   * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_iq2_s_ref  (const float * GGML_RESTRICT x, block_iq2_s   * GGML_RESTRICT y, int64_t k);
+
+// Dequantization
+GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq2_xs (const block_iq2_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq2_s  (const block_iq2_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq1_s  (const block_iq1_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq1_m  (const block_iq1_m   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq4_nl (const block_iq4_nl  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq4_xs (const block_iq4_xs  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_iq3_s  (const block_iq3_s   * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq2_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq1_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq1_m  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_iq3_s  (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+GGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+GGML_API void iq2xs_init_impl(enum ggml_type type);
+GGML_API void iq2xs_free_impl(enum ggml_type type);
+GGML_API void iq3xs_init_impl(int grid_size);
+GGML_API void iq3xs_free_impl(int grid_size);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml-threading.cpp b/llama/ggml-threading.cpp
new file mode 100644
index 000000000..7559b3366
--- /dev/null
+++ b/llama/ggml-threading.cpp
@@ -0,0 +1,38 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "ggml-threading.h"
+#include <mutex>
+
+std::mutex ggml_critical_section_mutex;
+
+void ggml_critical_section_start() {
+    ggml_critical_section_mutex.lock();
+}
+
+void ggml_critical_section_end(void) {
+    ggml_critical_section_mutex.unlock();
+}
diff --git a/llama/ggml-threading.h b/llama/ggml-threading.h
new file mode 100644
index 000000000..fe2ce3679
--- /dev/null
+++ b/llama/ggml-threading.h
@@ -0,0 +1,40 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+GGML_API void ggml_critical_section_start(void);
+GGML_API void ggml_critical_section_end(void);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/llama/ggml.c b/llama/ggml.c
new file mode 100644
index 000000000..8d442e08a
--- /dev/null
+++ b/llama/ggml.c
@@ -0,0 +1,7758 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows
+#define _USE_MATH_DEFINES // For M_PI on MSVC
+
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-threading.h"
+#include "ggml.h"
+
+// FIXME: required here for quantization functions
+#include "ggml-quants.h"
+
+#ifdef GGML_USE_CPU_HBM
+#include <hbwmalloc.h>
+#endif
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif
+
+#include <assert.h>
+#include <errno.h>
+#include <time.h>
+#include <math.h>
+#include <stdlib.h>
+#include <string.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <float.h>
+#include <limits.h>
+#include <stdarg.h>
+#include <signal.h>
+#if defined(__gnu_linux__)
+#include <syscall.h>
+#endif
+
+#if defined(__APPLE__)
+#include <unistd.h>
+#include <mach/mach.h>
+#include <TargetConditionals.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+#define UNUSED GGML_UNUSED
+
+#if defined(_MSC_VER)
+#define m512bh(p) p
+#define m512i(p) p
+#else
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+#endif
+
+// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
+float ggml_table_f32_f16[1 << 16];
+
+#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \
+    (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH))
+#include <unistd.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+#if defined(__ANDROID__)
+#include <unwind.h>
+#include <dlfcn.h>
+#include <stdio.h>
+
+struct backtrace_state {
+    void ** current;
+    void ** end;
+};
+
+static _Unwind_Reason_Code unwind_callback(struct _Unwind_Context* context, void* arg) {
+    struct backtrace_state * state = (struct backtrace_state *)arg;
+    uintptr_t pc = _Unwind_GetIP(context);
+    if (pc) {
+        if (state->current == state->end) {
+            return _URC_END_OF_STACK;
+        } else {
+            *state->current++ = (void*)pc;
+        }
+    }
+    return _URC_NO_REASON;
+}
+
+static void ggml_print_backtrace_symbols(void) {
+    const int max = 100;
+    void* buffer[max];
+
+    struct backtrace_state state = {buffer, buffer + max};
+    _Unwind_Backtrace(unwind_callback, &state);
+
+    int count = state.current - buffer;
+
+    for (int idx = 0; idx < count; ++idx) {
+        const void * addr = buffer[idx];
+        const char * symbol = "";
+
+        Dl_info info;
+        if (dladdr(addr, &info) && info.dli_sname) {
+            symbol = info.dli_sname;
+        }
+
+        fprintf(stderr, "%d: %p %s\n", idx, addr, symbol);
+    }
+}
+#elif defined(__linux__) && defined(__GLIBC__)
+#include <execinfo.h>
+static void ggml_print_backtrace_symbols(void) {
+    void * trace[100];
+    int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0]));
+    backtrace_symbols_fd(trace, nptrs, STDERR_FILENO);
+}
+#else
+static void ggml_print_backtrace_symbols(void) {
+    // platform not supported
+}
+#endif
+
+static void ggml_print_backtrace(void) {
+    char attach[32];
+    snprintf(attach, sizeof(attach), "attach %d", getpid());
+    int pid = fork();
+    if (pid == 0) {
+        // try gdb
+        execlp("gdb", "gdb", "--batch",
+            "-ex", "set style enabled on",
+            "-ex", attach,
+            "-ex", "bt -frame-info source-and-location",
+            "-ex", "detach",
+            "-ex", "quit",
+            (char *) NULL);
+        // try lldb
+        execlp("lldb", "lldb", "--batch",
+            "-o", "bt",
+            "-o", "quit",
+            "-p", attach,
+            (char *) NULL);
+        exit(EXIT_FAILURE);
+    } else {
+        int wstatus;
+        waitpid(pid, &wstatus, 0);
+        if (WIFEXITED(wstatus)) {
+            if (WEXITSTATUS(wstatus) == EXIT_FAILURE) {
+                // gdb failed, fallback to backtrace_symbols
+                ggml_print_backtrace_symbols();
+            }
+        }
+    }
+}
+#else
+static void ggml_print_backtrace(void) {
+    // platform not supported
+}
+#endif
+
+void ggml_abort(const char * file, int line, const char * fmt, ...) {
+    fflush(stdout);
+
+    fprintf(stderr, "%s:%d: ", file, line);
+
+    va_list args;
+    va_start(args, fmt);
+    vfprintf(stderr, fmt, args);
+    va_end(args);
+
+    fprintf(stderr, "\n");
+
+    ggml_print_backtrace();
+    abort();
+}
+
+//
+// logging
+//
+
+struct ggml_logger_state {
+    ggml_log_callback log_callback;
+    void * log_callback_user_data;
+};
+static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
+
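+// format into a small stack buffer first; fall back to a heap allocation only when the
+// rendered message does not fit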
+static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
+    if (format == NULL) {
+        return;
+    }
+    va_list args_copy;
+    va_copy(args_copy, args);
+    char buffer[128];
+    int len = vsnprintf(buffer, 128, format, args);
+    if (len < 128) {
+        g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
+    } else {
+        char * buffer2 = (char *) calloc(len + 1, sizeof(char));
+        vsnprintf(buffer2, len + 1, format, args_copy);
+        buffer2[len] = 0;
+        g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
+        free(buffer2);
+    }
+    va_end(args_copy);
+}
+
+void ggml_log_internal(enum ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    ggml_log_internal_v(level, format, args);
+    va_end(args);
+}
+
+void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
+//
+// end of logging block
+//
+
+#ifdef GGML_USE_ACCELERATE
+// uncomment to use vDSP for soft max computation
+// note: not sure if it is actually faster
+//#define GGML_SOFT_MAX_ACCELERATE
+#endif
+
+
+void * ggml_aligned_malloc(size_t size) {
+    const int alignment = 64;
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+    return _aligned_malloc(size, alignment);
+#else
+    if (size == 0) {
+        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
+        return NULL;
+    }
+    void * aligned_memory = NULL;
+  #ifdef GGML_USE_CPU_HBM
+    int result = hbw_posix_memalign(&aligned_memory, alignment, size);
+  #elif TARGET_OS_OSX
+    GGML_UNUSED(alignment);
+    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
+    int result = EFAULT;
+    switch (alloc_status) {
+        case KERN_SUCCESS:
+            result = 0;
+            break;
+        case KERN_INVALID_ADDRESS:
+            result = EINVAL;
+            break;
+        case KERN_NO_SPACE:
+            result = ENOMEM;
+            break;
+        default:
+            result = EFAULT;
+            break;
+    }
+  #else
+    int result = posix_memalign(&aligned_memory, alignment, size);
+  #endif
+    if (result != 0) {
+        // Handle allocation failure
+        const char *error_desc = "unknown allocation error";
+        switch (result) {
+            case EINVAL:
+                error_desc = "invalid alignment value";
+                break;
+            case ENOMEM:
+                error_desc = "insufficient memory";
+                break;
+        }
+        GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
+        return NULL;
+    }
+    return aligned_memory;
+#endif
+}
+
+void ggml_aligned_free(void * ptr, size_t size) {
+    GGML_UNUSED(size);
+#if defined(_MSC_VER) || defined(__MINGW32__)
+    _aligned_free(ptr);
+#elif GGML_USE_CPU_HBM
+    if (ptr != NULL) {
+        hbw_free(ptr);
+    }
+#elif TARGET_OS_OSX
+    if (ptr != NULL) {
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
+    }
+#else
+    free(ptr);
+#endif
+}
+
+
+inline static void * ggml_malloc(size_t size) {
+    if (size == 0) {
+        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n");
+        return NULL;
+    }
+    void * result = malloc(size);
+    if (result == NULL) {
+        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
+        GGML_ABORT("fatal error");
+    }
+    return result;
+}
+
+// calloc
+inline static void * ggml_calloc(size_t num, size_t size) {
+    if (num == 0 || size == 0) {
+        GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n");
+        return NULL;
+    }
+    void * result = calloc(num, size);
+    if (result == NULL) {
+        GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
+        GGML_ABORT("fatal error");
+    }
+    return result;
+}
+
+#define GGML_MALLOC(size)      ggml_malloc(size)
+#define GGML_CALLOC(num, size) ggml_calloc(num, size)
+
+#define GGML_FREE(ptr) free(ptr)
+
+const char * ggml_status_to_string(enum ggml_status status) {
+    switch (status) {
+        case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
+        case GGML_STATUS_FAILED:       return "GGML status: error (operation failed)";
+        case GGML_STATUS_SUCCESS:      return "GGML status: success";
+        case GGML_STATUS_ABORTED:      return "GGML status: warning (operation aborted)";
+    }
+
+    return "GGML status: unknown";
+}
+
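+// note: each scalar conversion helper below is immediately shadowed by a #define of its own name,
+// which poisons any later use of the function inside this file and forces internal code to go
+// through the GGML_FP16_TO_FP32 / GGML_FP32_TO_FP16 (and BF16) macros instead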
+float ggml_fp16_to_fp32(ggml_fp16_t x) {
+#define ggml_fp16_to_fp32 do_not_use__ggml_fp16_to_fp32__in_ggml
+    return GGML_FP16_TO_FP32(x);
+}
+
+ggml_fp16_t ggml_fp32_to_fp16(float x) {
+#define ggml_fp32_to_fp16 do_not_use__ggml_fp32_to_fp16__in_ggml
+    return GGML_FP32_TO_FP16(x);
+}
+
+float ggml_bf16_to_fp32(ggml_bf16_t x) {
+#define ggml_bf16_to_fp32 do_not_use__ggml_bf16_to_fp32__in_ggml
+    return GGML_BF16_TO_FP32(x);  // it just left shifts
+}
+
+ggml_bf16_t ggml_fp32_to_bf16(float x) {
+#define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml
+    return GGML_FP32_TO_BF16(x);
+}
+
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
+    for (int64_t i = 0; i < n; i++) {
+        y[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+}
+
+// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
+//        currently, the ggml_cpu_has_* functions are entirely compile-time
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
+    int64_t i = 0;
+#if defined(__F16C__)
+    //if (ggml_cpu_has_f16c()) {
+        for (; i + 7 < n; i += 8) {
+            __m256 x_vec = _mm256_loadu_ps(x + i);
+            __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+            _mm_storeu_si128((__m128i *)(y + i), y_vec);
+        }
+        for(; i + 3 < n; i += 4) {
+            __m128 x_vec = _mm_loadu_ps(x + i);
+            __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
+            _mm_storel_epi64((__m128i *)(y + i), y_vec);
+        }
+    //}
+#endif
+    for (; i < n; i++) {
+        y[i] = GGML_FP32_TO_FP16(x[i]);
+    }
+}
+
+void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
+    int64_t i = 0;
+#if defined(__AVX512F__)
+    //if (ggml_cpu_has_avx512()) {
+        for (; i + 16 <= n; i += 16) {
+            _mm512_storeu_ps(y + i,
+                            _mm512_castsi512_ps(
+                                _mm512_slli_epi32(
+                                    _mm512_cvtepu16_epi32(
+                                        _mm256_loadu_si256(
+                                            (const __m256i *)(x + i))),
+                                    16)));
+        }
+    //}
+#endif
+#if defined(__AVX2__)
+    //if (ggml_cpu_has_avx2()) {
+        for (; i + 8 <= n; i += 8) {
+            _mm256_storeu_ps(y + i,
+                            _mm256_castsi256_ps(
+                                _mm256_slli_epi32(
+                                    _mm256_cvtepu16_epi32(
+                                        _mm_loadu_si128(
+                                            (const __m128i *)(x + i))),
+                                    16)));
+        }
+    //}
+#endif
+    for (; i < n; i++) {
+        y[i] = GGML_BF16_TO_FP32(x[i]);
+    }
+}
+
+void ggml_fp32_to_bf16_row_ref(const float * x, ggml_bf16_t * y, int64_t n) {
+    for (int i = 0; i < n; i++) {
+        y[i] = ggml_compute_fp32_to_bf16(x[i]);
+    }
+}
+
+void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
+  int i = 0;
+#if defined(__AVX512BF16__)
+  // subnormals are flushed to zero on this platform
+  for (; i + 32 <= n; i += 32) {
+        _mm512_storeu_si512(
+            (__m512i *)(y + i),
+            m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+                                _mm512_loadu_ps(x + i))));
+  }
+#endif
+    for (; i < n; i++) {
+        y[i] = GGML_FP32_TO_BF16(x[i]);
+    }
+}
+
+bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) {
+    return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0;
+}
+
+//
+// timing
+//
+
+#if defined(_MSC_VER) || defined(__MINGW32__)
+static int64_t timer_freq, timer_start;
+void ggml_time_init(void) {
+    LARGE_INTEGER t;
+    QueryPerformanceFrequency(&t);
+    timer_freq = t.QuadPart;
+
+    // The multiplication by 1000 or 1000000 below can cause an overflow if timer_freq
+    // and the uptime are high enough.
+    // We subtract the program start time to reduce the likelihood of that happening.
+    QueryPerformanceCounter(&t);
+    timer_start = t.QuadPart;
+}
+int64_t ggml_time_ms(void) {
+    LARGE_INTEGER t;
+    QueryPerformanceCounter(&t);
+    return ((t.QuadPart-timer_start) * 1000) / timer_freq;
+}
+int64_t ggml_time_us(void) {
+    LARGE_INTEGER t;
+    QueryPerformanceCounter(&t);
+    return ((t.QuadPart-timer_start) * 1000000) / timer_freq;
+}
+#else
+void ggml_time_init(void) {}
+int64_t ggml_time_ms(void) {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000;
+}
+
+int64_t ggml_time_us(void) {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000;
+}
+#endif
+
+int64_t ggml_cycles(void) {
+    return clock();
+}
+
+int64_t ggml_cycles_per_ms(void) {
+    return CLOCKS_PER_SEC/1000;
+}
+
+//
+// cross-platform UTF-8 file paths
+//
+
+#ifdef _WIN32
+static wchar_t * ggml_mbstowcs(const char * mbs) {
+    int wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, NULL, 0);
+    if (!wlen) {
+        errno = EINVAL;
+        return NULL;
+    }
+
+    wchar_t * wbuf = GGML_MALLOC(wlen * sizeof(wchar_t));
+    wlen = MultiByteToWideChar(CP_UTF8, 0, mbs, -1, wbuf, wlen);
+    if (!wlen) {
+        GGML_FREE(wbuf);
+        errno = EINVAL;
+        return NULL;
+    }
+
+    return wbuf;
+}
+#endif
+
+FILE * ggml_fopen(const char * fname, const char * mode) {
+#ifdef _WIN32
+    FILE * file = NULL;
+
+    // convert fname (UTF-8)
+    wchar_t * wfname = ggml_mbstowcs(fname);
+    if (wfname) {
+        // convert mode (ANSI)
+        wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t));
+        wchar_t * wmode_p = wmode;
+        do {
+            *wmode_p++ = (wchar_t)*mode;
+        } while (*mode++);
+
+        // open file
+        file = _wfopen(wfname, wmode);
+
+        GGML_FREE(wfname);
+        GGML_FREE(wmode);
+    }
+
+    return file;
+#else
+    return fopen(fname, mode);
+#endif
+
+}
+static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
+static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
+
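+// per-type metadata (name, block size, element size) and conversion callbacks, indexed by
+// enum ggml_type; entries written with a bare integer index are removed or deprecated types
+// that only keep a placeholder name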
+static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_I8] = {
+        .type_name                = "i8",
+        .blck_size                = 1,
+        .type_size                = sizeof(int8_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I16] = {
+        .type_name                = "i16",
+        .blck_size                = 1,
+        .type_size                = sizeof(int16_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I32] = {
+        .type_name                = "i32",
+        .blck_size                = 1,
+        .type_size                = sizeof(int32_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_I64] = {
+        .type_name                = "i64",
+        .blck_size                = 1,
+        .type_size                = sizeof(int64_t),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_F64] = {
+        .type_name                = "f64",
+        .blck_size                = 1,
+        .type_size                = sizeof(double),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_F32] = {
+        .type_name                = "f32",
+        .blck_size                = 1,
+        .type_size                = sizeof(float),
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_F16] = {
+        .type_name                = "f16",
+        .blck_size                = 1,
+        .type_size                = sizeof(ggml_fp16_t),
+        .is_quantized             = false,
+        .to_float                 = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+        .from_float_ref           = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+    },
+    [GGML_TYPE_Q4_0] = {
+        .type_name                = "q4_0",
+        .blck_size                = QK4_0,
+        .type_size                = sizeof(block_q4_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_0,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_0_ref,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .type_name                = "q4_1",
+        .blck_size                = QK4_1,
+        .type_size                = sizeof(block_q4_1),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_1,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_1_ref,
+    },
+    [4] = { // GGML_TYPE_Q4_2
+        .type_name                = "DEPRECATED",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [5] = { // GGML_TYPE_Q4_3
+        .type_name                = "DEPRECATED",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_Q5_0] = {
+        .type_name                = "q5_0",
+        .blck_size                = QK5_0,
+        .type_size                = sizeof(block_q5_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_0,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_0_ref,
+    },
+    [GGML_TYPE_Q5_1] = {
+        .type_name                = "q5_1",
+        .blck_size                = QK5_1,
+        .type_size                = sizeof(block_q5_1),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_1,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_1_ref,
+    },
+    [GGML_TYPE_Q8_0] = {
+        .type_name                = "q8_0",
+        .blck_size                = QK8_0,
+        .type_size                = sizeof(block_q8_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q8_0,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_0_ref,
+    },
+    [GGML_TYPE_Q8_1] = {
+        .type_name                = "q8_1",
+        .blck_size                = QK8_1,
+        .type_size                = sizeof(block_q8_1),
+        .is_quantized             = true,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q8_1_ref,
+    },
+    [GGML_TYPE_Q2_K] = {
+        .type_name                = "q2_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q2_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q2_K,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q2_K_ref,
+    },
+    [GGML_TYPE_Q3_K] = {
+        .type_name                = "q3_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q3_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q3_K,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q3_K_ref,
+    },
+    [GGML_TYPE_Q4_K] = {
+        .type_name                = "q4_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q4_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q4_K,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q4_K_ref,
+    },
+    [GGML_TYPE_Q5_K] = {
+        .type_name                = "q5_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q5_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q5_K,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q5_K_ref,
+    },
+    [GGML_TYPE_Q6_K] = {
+        .type_name                = "q6_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q6_K),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q6_K,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_q6_K_ref,
+    },
+    [GGML_TYPE_IQ2_XXS] = {
+        .type_name                = "iq2_xxs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq2_xxs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xxs,
+        .from_float_ref           = NULL,
+    },
+    [GGML_TYPE_IQ2_XS] = {
+        .type_name                = "iq2_xs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq2_xs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_xs,
+        .from_float_ref           = NULL,
+    },
+    [GGML_TYPE_IQ3_XXS] = {
+        .type_name                = "iq3_xxs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq3_xxs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_xxs,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
+    },
+    [GGML_TYPE_IQ3_S] = {
+        .type_name                = "iq3_s",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq3_s),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq3_s,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_iq3_s_ref,
+    },
+    [GGML_TYPE_IQ2_S] = {
+        .type_name                = "iq2_s",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq2_s),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq2_s,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_iq2_s_ref,
+    },
+    [GGML_TYPE_IQ1_S] = {
+        .type_name                = "iq1_s",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq1_s),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_s,
+        .from_float_ref           = NULL,
+    },
+    [GGML_TYPE_IQ1_M] = {
+        .type_name                = "iq1_m",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq1_m),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq1_m,
+        .from_float_ref           = NULL,
+    },
+    [GGML_TYPE_IQ4_NL] = {
+        .type_name                = "iq4_nl",
+        .blck_size                = QK4_NL,
+        .type_size                = sizeof(block_iq4_nl),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_nl,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_nl_ref,
+    },
+    [GGML_TYPE_IQ4_XS] = {
+        .type_name                = "iq4_xs",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_iq4_xs),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_iq4_xs,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_iq4_xs_ref,
+    },
+    [GGML_TYPE_Q8_K] = {
+        .type_name                = "q8_K",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_q8_K),
+        .is_quantized             = true,
+    },
+    [GGML_TYPE_BF16] = {
+        .type_name                = "bf16",
+        .blck_size                = 1,
+        .type_size                = sizeof(ggml_bf16_t),
+        .is_quantized             = false,
+        .to_float                 = (ggml_to_float_t) ggml_bf16_to_fp32_row,
+        .from_float_ref           = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref,
+    },
+    [31] = { // GGML_TYPE_Q4_0_4_4
+        .type_name                = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [32] = { // GGML_TYPE_Q4_0_4_8
+        .type_name                = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [33] = { // GGML_TYPE_Q4_0_8_8
+        .type_name                = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [GGML_TYPE_TQ1_0] = {
+        .type_name                = "tq1_0",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_tq1_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_tq1_0,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_tq1_0_ref,
+    },
+    [GGML_TYPE_TQ2_0] = {
+        .type_name                = "tq2_0",
+        .blck_size                = QK_K,
+        .type_size                = sizeof(block_tq2_0),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_tq2_0,
+        .from_float_ref           = (ggml_from_float_t) quantize_row_tq2_0_ref,
+    },
+    [36] = { // GGML_TYPE_IQ4_NL_4_4
+        .type_name                = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [37] = { // GGML_TYPE_IQ4_NL_4_8
+        .type_name                = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+    [38] = { // GGML_TYPE_IQ4_NL_8_8
+        .type_name                = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking",
+        .blck_size                = 0,
+        .type_size                = 0,
+        .is_quantized             = false,
+    },
+};
+
+const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
+    GGML_ASSERT(type < GGML_TYPE_COUNT);
+    return &type_traits[type];
+}
+
+//
+// ggml object
+//
+
+struct ggml_object {
+    size_t offs;
+    size_t size;
+
+    struct ggml_object * next;
+
+    enum ggml_object_type type;
+
+    char padding[4];
+};
+
+static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+//
+// ggml context
+//
+
+struct ggml_context {
+    size_t mem_size;
+    void * mem_buffer;
+    bool   mem_buffer_owned;
+    bool   no_alloc;
+
+    int    n_objects;
+
+    struct ggml_object * objects_begin;
+    struct ggml_object * objects_end;
+};
+
+struct ggml_context_container {
+    bool used;
+
+    struct ggml_context context;
+};
+
+//
+// data types
+//
+
+static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+    "NONE",
+
+    "DUP",
+    "ADD",
+    "ADD1",
+    "ACC",
+    "SUB",
+    "MUL",
+    "DIV",
+    "SQR",
+    "SQRT",
+    "LOG",
+    "SIN",
+    "COS",
+    "SUM",
+    "SUM_ROWS",
+    "MEAN",
+    "ARGMAX",
+    "COUNT_EQUAL",
+    "REPEAT",
+    "REPEAT_BACK",
+    "CONCAT",
+    "SILU_BACK",
+    "NORM",
+    "RMS_NORM",
+    "RMS_NORM_BACK",
+    "GROUP_NORM",
+
+    "MUL_MAT",
+    "MUL_MAT_ID",
+    "OUT_PROD",
+
+    "SCALE",
+    "SET",
+    "CPY",
+    "CONT",
+    "RESHAPE",
+    "VIEW",
+    "PERMUTE",
+    "TRANSPOSE",
+    "GET_ROWS",
+    "GET_ROWS_BACK",
+    "DIAG",
+    "DIAG_MASK_INF",
+    "DIAG_MASK_ZERO",
+    "SOFT_MAX",
+    "SOFT_MAX_BACK",
+    "ROPE",
+    "ROPE_BACK",
+    "CLAMP",
+    "CONV_TRANSPOSE_1D",
+    "IM2COL",
+    "IM2COL_BACK",
+    "CONV_TRANSPOSE_2D",
+    "POOL_1D",
+    "POOL_2D",
+    "POOL_2D_BACK",
+    "UPSCALE",
+    "PAD",
+    "PAD_REFLECT_1D",
+    "UNPAD",
+    "ARANGE",
+    "TIMESTEP_EMBEDDING",
+    "ARGSORT",
+    "LEAKY_RELU",
+
+    "FLASH_ATTN_EXT",
+    "FLASH_ATTN_BACK",
+    "SSM_CONV",
+    "SSM_SCAN",
+    "WIN_PART",
+    "WIN_UNPART",
+    "GET_REL_POS",
+    "ADD_REL_POS",
+    "RWKV_WKV6",
+
+    "UNARY",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
+
+    "MAP_CUSTOM1_F32",
+    "MAP_CUSTOM2_F32",
+    "MAP_CUSTOM3_F32",
+
+    "MAP_CUSTOM1",
+    "MAP_CUSTOM2",
+    "MAP_CUSTOM3",
+
+    "CROSS_ENTROPY_LOSS",
+    "CROSS_ENTROPY_LOSS_BACK",
+    "OPT_STEP_ADAMW",
+};
+
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+
+static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+    "none",
+
+    "x",
+    "x+y",
+    "x+y",
+    "view(x,nb,offset)+=y->x",
+    "x-y",
+    "x*y",
+    "x/y",
+    "x^2",
+    "√x",
+    "log(x)",
+    "sin(x)",
+    "cos(x)",
+    "Σx",
+    "Σx_k",
+    "Σx/n",
+    "argmax(x)",
+    "count_equal(x)",
+    "repeat(x)",
+    "repeat_back(x)",
+    "concat(x, y)",
+    "silu_back(x)",
+    "norm(x)",
+    "rms_norm(x)",
+    "rms_norm_back(x)",
+    "group_norm(x)",
+
+    "X*Y",
+    "X[i]*Y",
+    "X*Y",
+
+    "x*v",
+    "y-\\>view(x)",
+    "x-\\>y",
+    "cont(x)",
+    "reshape(x)",
+    "view(x)",
+    "permute(x)",
+    "transpose(x)",
+    "get_rows(x)",
+    "get_rows_back(x)",
+    "diag(x)",
+    "diag_mask_inf(x)",
+    "diag_mask_zero(x)",
+    "soft_max(x)",
+    "soft_max_back(x)",
+    "rope(x)",
+    "rope_back(x)",
+    "clamp(x)",
+    "conv_transpose_1d(x)",
+    "im2col(x)",
+    "im2col_back(x)",
+    "conv_transpose_2d(x)",
+    "pool_1d(x)",
+    "pool_2d(x)",
+    "pool_2d_back(x)",
+    "upscale(x)",
+    "pad(x)",
+    "pad_reflect_1d(x)",
+    "unpad(x)",
+    "arange(start, stop, step)",
+    "timestep_embedding(timesteps, dim, max_period)",
+    "argsort(x)",
+    "leaky_relu(x)",
+
+    "flash_attn_ext(x)",
+    "flash_attn_back(x)",
+    "ssm_conv(x)",
+    "ssm_scan(x)",
+    "win_part(x)",
+    "win_unpart(x)",
+    "get_rel_pos(x)",
+    "add_rel_pos(x)",
+    "rwkv_wkv6(k, v, r, tf, td, s)",
+
+    "unary(x)",
+
+    "f(x)",
+    "f(x,y)",
+
+    "custom_f32(x)",
+    "custom_f32(x,y)",
+    "custom_f32(x,y,z)",
+
+    "custom(x)",
+    "custom(x,y)",
+    "custom(x,y,z)",
+
+    "cross_entropy_loss(x,y)",
+    "cross_entropy_loss_back(x,y)",
+    "adamw(x)",
+};
+
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+
+static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
+
+
+static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
+    "ABS",
+    "SGN",
+    "NEG",
+    "STEP",
+    "TANH",
+    "ELU",
+    "RELU",
+    "SIGMOID",
+    "GELU",
+    "GELU_QUICK",
+    "SILU",
+    "HARDSWISH",
+    "HARDSIGMOID",
+    "EXP",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
+
+
+static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
+static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
+
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_print_object(const struct ggml_object * obj) {
+    GGML_LOG_INFO(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
+            obj->type, obj->offs, obj->size, (const void *) obj->next);
+}
+
+void ggml_print_objects(const struct ggml_context * ctx) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx);
+
+    while (obj != NULL) {
+        ggml_print_object(obj);
+        obj = obj->next;
+    }
+
+    GGML_LOG_INFO("%s: --- end ---\n", __func__);
+}
+
+int64_t ggml_nelements(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
+int64_t ggml_nrows(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
+}
+
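+// total number of bytes spanned by the tensor data, computed from the strides so that padded
+// or permuted tensors report the full extent of memory they touch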
+size_t ggml_nbytes(const struct ggml_tensor * tensor) {
+    size_t nbytes;
+    const size_t blck_size = ggml_blck_size(tensor->type);
+    if (blck_size == 1) {
+        nbytes = ggml_type_size(tensor->type);
+        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+        }
+    }
+    else {
+        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+        }
+    }
+
+    return nbytes;
+}
+
+size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
+    return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
+}
+
+int64_t ggml_blck_size(enum ggml_type type) {
+    return type_traits[type].blck_size;
+}
+
+size_t ggml_type_size(enum ggml_type type) {
+    return type_traits[type].type_size;
+}
+
+size_t ggml_row_size(enum ggml_type type, int64_t ne) {
+    assert(ne % ggml_blck_size(type) == 0);
+    return ggml_type_size(type)*ne/ggml_blck_size(type);
+}
+
+double ggml_type_sizef(enum ggml_type type) {
+    return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
+}
+
+const char * ggml_type_name(enum ggml_type type) {
+    return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
+}
+
+bool ggml_is_quantized(enum ggml_type type) {
+    return type_traits[type].is_quantized;
+}
+
+const char * ggml_op_name(enum ggml_op op) {
+    return GGML_OP_NAME[op];
+}
+
+const char * ggml_op_symbol(enum ggml_op op) {
+    return GGML_OP_SYMBOL[op];
+}
+
+const char * ggml_unary_op_name(enum ggml_unary_op op) {
+    return GGML_UNARY_OP_NAME[op];
+}
+
+const char * ggml_op_desc(const struct ggml_tensor * t) {
+    if (t->op == GGML_OP_UNARY) {
+        enum ggml_unary_op uop = ggml_get_unary_op(t);
+        return ggml_unary_op_name(uop);
+    }
+    return ggml_op_name(t->op);
+}
+
+size_t ggml_element_size(const struct ggml_tensor * tensor) {
+    return ggml_type_size(tensor->type);
+}
+
+bool ggml_is_scalar(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[0] == 1 && tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_vector(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[1] == 1 && tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_matrix(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
+
+bool ggml_is_3d(const struct ggml_tensor * tensor) {
+    return tensor->ne[3] == 1;
+}
+
+int ggml_n_dims(const struct ggml_tensor * tensor) {
+    for (int i = GGML_MAX_DIMS - 1; i >= 1; --i) {
+        if (tensor->ne[i] > 1) {
+            return i + 1;
+        }
+    }
+    return 1;
+}
+
+enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
+    enum ggml_type wtype = GGML_TYPE_COUNT;
+
+    switch (ftype) {
+        case GGML_FTYPE_ALL_F32:              wtype = GGML_TYPE_F32;   break;
+        case GGML_FTYPE_MOSTLY_F16:           wtype = GGML_TYPE_F16;   break;
+        case GGML_FTYPE_MOSTLY_BF16:          wtype = GGML_TYPE_BF16;  break;
+        case GGML_FTYPE_MOSTLY_Q4_0:          wtype = GGML_TYPE_Q4_0;  break;
+        case GGML_FTYPE_MOSTLY_Q4_1:          wtype = GGML_TYPE_Q4_1;  break;
+        case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
+        case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
+        case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
+        case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
+        case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
+        case GGML_FTYPE_MOSTLY_Q5_K:          wtype = GGML_TYPE_Q5_K;  break;
+        case GGML_FTYPE_MOSTLY_Q6_K:          wtype = GGML_TYPE_Q6_K;  break;
+        case GGML_FTYPE_MOSTLY_IQ2_XXS:       wtype = GGML_TYPE_IQ2_XXS;  break;
+        case GGML_FTYPE_MOSTLY_IQ2_XS:        wtype = GGML_TYPE_IQ2_XS;   break;
+        case GGML_FTYPE_MOSTLY_IQ3_XXS:       wtype = GGML_TYPE_IQ3_XXS;  break;
+        case GGML_FTYPE_MOSTLY_IQ1_S:         wtype = GGML_TYPE_IQ1_S;    break;
+        case GGML_FTYPE_MOSTLY_IQ1_M:         wtype = GGML_TYPE_IQ1_M;    break;
+        case GGML_FTYPE_MOSTLY_IQ4_NL:        wtype = GGML_TYPE_IQ4_NL;   break;
+        case GGML_FTYPE_MOSTLY_IQ4_XS:        wtype = GGML_TYPE_IQ4_XS;   break;
+        case GGML_FTYPE_MOSTLY_IQ3_S:         wtype = GGML_TYPE_IQ3_S;    break;
+        case GGML_FTYPE_MOSTLY_IQ2_S:         wtype = GGML_TYPE_IQ2_S;    break;
+        case GGML_FTYPE_UNKNOWN:              wtype = GGML_TYPE_COUNT; break;
+        case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
+    }
+
+    GGML_ASSERT(wtype != GGML_TYPE_COUNT);
+
+    return wtype;
+}
+
+size_t ggml_tensor_overhead(void) {
+    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE;
+}
+
+bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+    return tensor->nb[0] > tensor->nb[1];
+}
+
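+// true if the data is laid out contiguously, while allowing arbitrary strides for dimensions
+// 1..n, e.g. ggml_is_contiguous_1 (n == 1) still accepts rows with trailing padding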
+static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
+    size_t next_nb = ggml_type_size(tensor->type);
+    if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
+        return false;
+    }
+    next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        if (tensor->ne[i] != 1) {
+            if (i > n) {
+                if (tensor->nb[i] != next_nb) {
+                    return false;
+                }
+                next_nb *= tensor->ne[i];
+            } else {
+                // this dimension does not need to be contiguous
+                next_nb = tensor->ne[i]*tensor->nb[i];
+            }
+        }
+    }
+    return true;
+}
+
+bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous_0(tensor);
+}
+
+bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous_n(tensor, 0);
+}
+
+bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous_n(tensor, 1);
+}
+
+bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous_n(tensor, 2);
+}
+
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
+static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
+bool ggml_is_empty(const struct ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        if (tensor->ne[i] == 0) {
+            // empty if any dimension has no elements
+            return true;
+        }
+    }
+    return false;
+}
+
+bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->ne[0] == t1->ne[0]) &&
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
+}
+
+bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->nb[0] == t1->nb[0]) &&
+        (t0->nb[1] == t1->nb[1]) &&
+        (t0->nb[2] == t1->nb[2]) &&
+        (t0->nb[3] == t1->nb[3]);
+}
+
+// check if t1 can be represented as a repetition of t0
+bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return ggml_is_empty(t0) ? ggml_is_empty(t1) :
+        (t1->ne[0]%t0->ne[0] == 0) &&
+        (t1->ne[1]%t0->ne[1] == 0) &&
+        (t1->ne[2]%t0->ne[2] == 0) &&
+        (t1->ne[3]%t0->ne[3] == 0);
+}
+
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
+// assert that pointer is aligned to GGML_MEM_ALIGN
+#define GGML_ASSERT_ALIGNED(ptr) \
+    GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_context * ggml_init(struct ggml_init_params params) {
+    static bool is_first_call = true;
+
+    ggml_critical_section_start();
+
+    if (is_first_call) {
+        // initialize time system (required on Windows)
+        ggml_time_init();
+
+        for (int i = 0; i < (1 << 16); ++i) {
+            union {
+                uint16_t u16;
+                ggml_fp16_t fp16;
+            } u = {i};
+            ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
+        }
+
+        is_first_call = false;
+    }
+
+    ggml_critical_section_end();
+
+    struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context));
+
+    // allow calling ggml_init with 0 size
+    if (params.mem_size == 0) {
+        params.mem_size = GGML_MEM_ALIGN;
+    }
+
+    const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
+
+    *ctx = (struct ggml_context) {
+        /*.mem_size           =*/ mem_size,
+        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
+        /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
+        /*.no_alloc           =*/ params.no_alloc,
+        /*.n_objects          =*/ 0,
+        /*.objects_begin      =*/ NULL,
+        /*.objects_end        =*/ NULL,
+    };
+
+    GGML_ASSERT(ctx->mem_buffer != NULL);
+
+    GGML_ASSERT_ALIGNED(ctx->mem_buffer);
+
+    GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
+
+    return ctx;
+}
+
+void ggml_reset(struct ggml_context * ctx) {
+    if (ctx == NULL) {
+        return;
+    }
+
+    ctx->n_objects     = 0;
+    ctx->objects_begin = NULL;
+    ctx->objects_end   = NULL;
+}
+
+void ggml_free(struct ggml_context * ctx) {
+    if (ctx == NULL) {
+        return;
+    }
+
+    if (ctx->mem_buffer_owned) {
+        ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
+    }
+
+    GGML_FREE(ctx);
+}
+
+size_t ggml_used_mem(const struct ggml_context * ctx) {
+    return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size;
+}
+
+bool ggml_get_no_alloc(struct ggml_context * ctx) {
+    return ctx->no_alloc;
+}
+
+void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc) {
+    ctx->no_alloc = no_alloc;
+}
+
+void * ggml_get_mem_buffer(const struct ggml_context * ctx) {
+    return ctx->mem_buffer;
+}
+
+size_t ggml_get_mem_size(const struct ggml_context * ctx) {
+    return ctx->mem_size;
+}
+
+size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) {
+    size_t max_size = 0;
+
+    for (struct ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor != NULL; tensor = ggml_get_next_tensor(ctx, tensor)) {
+        size_t bytes = ggml_nbytes(tensor);
+        max_size = MAX(max_size, bytes);
+    }
+
+    return max_size;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
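+// a context hands out memory with a simple bump allocator: each allocation is a ggml_object
+// header followed by its payload, appended at the current end of the pool and linked into a
+// singly linked list of objects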
+static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) {
+    // always insert objects at the end of the context's memory pool
+    struct ggml_object * obj_cur = ctx->objects_end;
+
+    const size_t cur_offs = obj_cur == NULL ? 0 : obj_cur->offs;
+    const size_t cur_size = obj_cur == NULL ? 0 : obj_cur->size;
+    const size_t cur_end  = cur_offs + cur_size;
+
+    // align to GGML_MEM_ALIGN
+    size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN);
+
+    char * const mem_buffer = ctx->mem_buffer;
+    struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end);
+
+    if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) {
+        GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
+                __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size);
+#ifndef NDEBUG
+        GGML_ABORT("not enough space in the context's memory pool");
+#endif
+        return NULL;
+    }
+
+    *obj_new = (struct ggml_object) {
+        .offs = cur_end + GGML_OBJECT_SIZE,
+        .size = size_needed,
+        .next = NULL,
+        .type = type,
+    };
+
+    GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs);
+
+    if (obj_cur != NULL) {
+        obj_cur->next = obj_new;
+    } else {
+        // this is the first object in this context
+        ctx->objects_begin = obj_new;
+    }
+
+    ctx->objects_end = obj_new;
+
+    //printf("%s: inserted new object at %zu, size = %zu\n", __func__, cur_end, obj_new->size);
+
+    return obj_new;
+}
+
+static struct ggml_tensor * ggml_new_tensor_impl(
+        struct ggml_context * ctx,
+        enum   ggml_type      type,
+        int                   n_dims,
+        const int64_t       * ne,
+        struct ggml_tensor  * view_src,
+        size_t                view_offs) {
+
+    GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);
+    GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
+
+    // find the base tensor and absolute offset
+    if (view_src != NULL && view_src->view_src != NULL) {
+        view_offs += view_src->view_offs;
+        view_src   = view_src->view_src;
+    }
+
+    size_t data_size = ggml_row_size(type, ne[0]);
+    for (int i = 1; i < n_dims; i++) {
+        data_size *= ne[i];
+    }
+
+    GGML_ASSERT(view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src));
+
+    void * data = view_src != NULL ? view_src->data : NULL;
+    if (data != NULL) {
+        data = (char *) data + view_offs;
+    }
+
+    size_t obj_alloc_size = 0;
+
+    if (view_src == NULL && !ctx->no_alloc) {
+        // allocate tensor data in the context's memory pool
+        obj_alloc_size = data_size;
+    }
+
+    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
+    GGML_ASSERT(obj_new);
+
+    struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
+
+#ifdef __clang__
+    // temporary until ggml_tensor::backend is removed
+    #pragma clang diagnostic push
+    #pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
+    *result = (struct ggml_tensor) {
+        /*.type         =*/ type,
+        /*.backend      =*/ GGML_BACKEND_TYPE_CPU,
+        /*.buffer       =*/ NULL,
+        /*.ne           =*/ { 1, 1, 1, 1 },
+        /*.nb           =*/ { 0, 0, 0, 0 },
+        /*.op           =*/ GGML_OP_NONE,
+        /*.op_params    =*/ { 0 },
+        /*.flags        =*/ 0,
+        /*.src          =*/ { NULL },
+        /*.view_src     =*/ view_src,
+        /*.view_offs    =*/ view_offs,
+        /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
+        /*.name         =*/ { 0 },
+        /*.extra        =*/ NULL,
+        /*.padding      =*/ { 0 },
+    };
+
+#ifdef __clang__
+    #pragma clang diagnostic pop
+#endif
+
+    // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
+    //GGML_ASSERT_ALIGNED(result->data);
+
+    for (int i = 0; i < n_dims; i++) {
+        result->ne[i] = ne[i];
+    }
+
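+    // default strides: nb[0] is the size of one element (one block for quantized types),
+    // nb[1] is the size of a full row, and every higher dimension multiplies up from there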
+    result->nb[0] = ggml_type_size(type);
+    result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
+    }
+
+    ctx->n_objects++;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_new_tensor(
+        struct ggml_context * ctx,
+        enum   ggml_type      type,
+        int                   n_dims,
+        const int64_t       * ne) {
+    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
+}
+
+struct ggml_tensor * ggml_new_tensor_1d(
+        struct ggml_context * ctx,
+        enum   ggml_type      type,
+        int64_t ne0) {
+    return ggml_new_tensor(ctx, type, 1, &ne0);
+}
+
+struct ggml_tensor * ggml_new_tensor_2d(
+        struct ggml_context * ctx,
+        enum   ggml_type      type,
+        int64_t ne0,
+        int64_t ne1) {
+    const int64_t ne[2] = { ne0, ne1 };
+    return ggml_new_tensor(ctx, type, 2, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_3d(
+        struct ggml_context * ctx,
+        enum   ggml_type      type,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2) {
+    const int64_t ne[3] = { ne0, ne1, ne2 };
+    return ggml_new_tensor(ctx, type, 3, ne);
+}
+
+struct ggml_tensor * ggml_new_tensor_4d(
+        struct ggml_context * ctx,
+        enum   ggml_type type,
+        int64_t ne0,
+        int64_t ne1,
+        int64_t ne2,
+        int64_t ne3) {
+    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+    return ggml_new_tensor(ctx, type, 4, ne);
+}
+
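+// Usage sketch (editorial comment, not part of the upstream source): assuming a
+// ggml_context `ctx` obtained from ggml_init() elsewhere, a 2D F32 tensor with
+// n_cols elements per row and n_rows rows is allocated from the context pool as
+//
+//     struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_cols, n_rows);
+//
+// ne[0] is the innermost (contiguous) dimension and nb[] holds byte strides.
+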
+void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) {
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes);
+
+    return (uint8_t *)ctx->mem_buffer + obj->offs;
+}
+
+struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
+    return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne);
+}
+
+void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) {
+    const int64_t ne2 = tensor->ne[2];
+    const int64_t ne1 = tensor->ne[1];
+    const int64_t ne0 = tensor->ne[0];
+
+    const int64_t i3_ = (i/(ne2*ne1*ne0));
+    const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0);
+    const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0;
+    const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0);
+
+    if (i0) {
+        * i0 = i0_;
+    }
+    if (i1) {
+        * i1 = i1_;
+    }
+    if (i2) {
+        * i2 = i2_;
+    }
+    if (i3) {
+        * i3 = i3_;
+    }
+}
+
+void * ggml_get_data(const struct ggml_tensor * tensor) {
+    return tensor->data;
+}
+
+float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
+    assert(tensor->type == GGML_TYPE_F32);
+    return (float *)(tensor->data);
+}
+
+enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_UNARY);
+    return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
+}
+
+const char * ggml_get_name(const struct ggml_tensor * tensor) {
+    return tensor->name;
+}
+
+struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) {
+    size_t i;
+    for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) {
+        tensor->name[i] = name[i];
+    }
+    tensor->name[i] = '\0';
+    return tensor;
+}
+
+struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
+    va_list args;
+    va_start(args, fmt);
+    vsnprintf(tensor->name, sizeof(tensor->name), fmt, args);
+    va_end(args);
+    return tensor;
+}
+
+struct ggml_tensor * ggml_view_tensor(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * src) {
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, GGML_MAX_DIMS, src->ne, src, 0);
+    ggml_format_name(result, "%s (view)", src->name);
+
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = src->nb[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+            return (struct ggml_tensor *)(mem_buffer + obj->offs);
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
+struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
+    obj = obj->next;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+            return (struct ggml_tensor *)(mem_buffer + obj->offs);
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
+struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TYPE_TENSOR) {
+            struct ggml_tensor * cur = (struct ggml_tensor *)(mem_buffer + obj->offs);
+            if (strcmp(cur->name, name) == 0) {
+                return cur;
+            }
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
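+// Note (editorial comment): the builder functions below do not perform any
+// computation. Each one only records a graph node by setting result->op and
+// result->src[...]; the actual work happens later, once the node is added to a
+// graph (e.g. with ggml_build_forward_expand) and the graph is computed by a
+// backend.
+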
+// ggml_dup
+
+static struct ggml_tensor * ggml_dup_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_DUP;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_dup(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_dup_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_dup_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_dup_impl(ctx, a, true);
+}
+
+// ggml_add
+
+static struct ggml_tensor * ggml_add_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_ADD;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_add_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_add_impl(ctx, a, b, true);
+}
+
+// ggml_add_cast
+
+static struct ggml_tensor * ggml_add_cast_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum   ggml_type      type) {
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
+
+    // currently only supported for quantized input and f16
+    GGML_ASSERT(ggml_is_quantized(a->type) ||
+                a->type == GGML_TYPE_F16 ||
+                a->type == GGML_TYPE_BF16);
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
+
+    result->op     = GGML_OP_ADD;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add_cast(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum   ggml_type      type) {
+    return ggml_add_cast_impl(ctx, a, b, type);
+}
+
+// ggml_add1
+
+static struct ggml_tensor * ggml_add1_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_scalar(b));
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_ADD1;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add1(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_add1_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_add1_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_add1_impl(ctx, a, b, true);
+}
+
+// ggml_acc
+
+static struct ggml_tensor * ggml_acc_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_ACC;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_acc(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+}
+
+struct ggml_tensor * ggml_acc_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+}
+
+// ggml_sub
+
+static struct ggml_tensor * ggml_sub_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SUB;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sub(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_sub_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_sub_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_sub_impl(ctx, a, b, true);
+}
+
+// ggml_mul
+
+static struct ggml_tensor * ggml_mul_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_MUL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_mul(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_mul_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_mul_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_mul_impl(ctx, a, b, true);
+}
+
+// ggml_div
+
+static struct ggml_tensor * ggml_div_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_DIV;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_div(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_div_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_div_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_div_impl(ctx, a, b, true);
+}
+
+// ggml_sqr
+
+static struct ggml_tensor * ggml_sqr_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SQR;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sqr(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqr_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqr_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqr_impl(ctx, a, true);
+}
+
+// ggml_sqrt
+
+static struct ggml_tensor * ggml_sqrt_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SQRT;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sqrt(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqrt_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sqrt_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sqrt_impl(ctx, a, true);
+}
+
+// ggml_log
+
+static struct ggml_tensor * ggml_log_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_LOG;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_log(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_log_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_log_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_log_impl(ctx, a, true);
+}
+
+// ggml_sin
+
+static struct ggml_tensor * ggml_sin_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SIN;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sin(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sin_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, true);
+}
+
+// ggml_cos
+
+static struct ggml_tensor * ggml_cos_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_COS;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, true);
+}
+
+// ggml_sum
+
+struct ggml_tensor * ggml_sum(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op     = GGML_OP_SUM;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_sum_rows
+
+struct ggml_tensor * ggml_sum_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    int64_t ne[GGML_MAX_DIMS] = { 1 };
+    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+        ne[i] = a->ne[i];
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    result->op     = GGML_OP_SUM_ROWS;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_mean
+
+struct ggml_tensor * ggml_mean(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_MEAN;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_argmax
+
+struct ggml_tensor * ggml_argmax(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    GGML_ASSERT(ggml_is_matrix(a));
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]);
+
+    result->op     = GGML_OP_ARGMAX;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_count_equal
+
+struct ggml_tensor * ggml_count_equal(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
+
+    result->op     = GGML_OP_COUNT_EQUAL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_repeat
+
+struct ggml_tensor * ggml_repeat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_repeat(a, b));
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
+
+    result->op     = GGML_OP_REPEAT;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_repeat_back
+
+struct ggml_tensor * ggml_repeat_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_repeat(b, a));
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne);
+
+    result->op     = GGML_OP_REPEAT_BACK;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_concat
+
+struct ggml_tensor * ggml_concat(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    struct ggml_tensor  * b,
+    int                   dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+
+    int64_t ne[GGML_MAX_DIMS];
+    for (int d = 0; d < GGML_MAX_DIMS; ++d) {
+        if (d == dim) {
+            ne[d] = a->ne[d] + b->ne[d];
+            continue;
+        }
+        GGML_ASSERT(a->ne[d] == b->ne[d]);
+        ne[d] = a->ne[d];
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne);
+
+    ggml_set_op_params_i32(result, 0, dim);
+
+    result->op     = GGML_OP_CONCAT;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_abs
+
+struct ggml_tensor * ggml_abs(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ABS);
+}
+
+struct ggml_tensor * ggml_abs_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ABS);
+}
+
+// ggml_sgn
+
+struct ggml_tensor * ggml_sgn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SGN);
+}
+
+struct ggml_tensor * ggml_sgn_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SGN);
+}
+
+// ggml_neg
+
+struct ggml_tensor * ggml_neg(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_NEG);
+}
+
+struct ggml_tensor * ggml_neg_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_NEG);
+}
+
+// ggml_step
+
+struct ggml_tensor * ggml_step(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_STEP);
+}
+
+struct ggml_tensor * ggml_step_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_STEP);
+}
+
+// ggml_tanh
+
+struct ggml_tensor * ggml_tanh(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_TANH);
+}
+
+struct ggml_tensor * ggml_tanh_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_TANH);
+}
+
+// ggml_elu
+
+struct ggml_tensor * ggml_elu(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_ELU);
+}
+
+struct ggml_tensor * ggml_elu_inplace(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU);
+}
+
+// ggml_relu
+
+struct ggml_tensor * ggml_relu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_RELU);
+}
+
+struct ggml_tensor * ggml_relu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
+}
+
+// ggml_leaky_relu
+
+struct ggml_tensor * ggml_leaky_relu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 negative_slope,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+    result->op     = GGML_OP_LEAKY_RELU;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+struct ggml_tensor * ggml_sigmoid_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+// ggml_gelu
+
+struct ggml_tensor * ggml_gelu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU);
+}
+
+struct ggml_tensor * ggml_gelu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
+}
+
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_QUICK);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_QUICK);
+}
+
+// ggml_silu
+
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SILU);
+}
+
+struct ggml_tensor * ggml_silu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
+}
+
+// ggml_silu_back
+
+struct ggml_tensor * ggml_silu_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SILU_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml hardswish
+
+struct ggml_tensor * ggml_hardswish(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSWISH);
+}
+
+// ggml hardsigmoid
+
+struct ggml_tensor * ggml_hardsigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_HARDSIGMOID);
+}
+
+// ggml exp
+
+struct ggml_tensor * ggml_exp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+struct ggml_tensor * ggml_exp_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
+}
+
+// ggml_norm
+
+static struct ggml_tensor * ggml_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
+    result->op     = GGML_OP_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
+}
+
+// ggml_rms_norm
+
+static struct ggml_tensor * ggml_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
+    result->op     = GGML_OP_RMS_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, false);
+}
+
+struct ggml_tensor * ggml_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, true);
+}
+
+// ggml_rms_norm_back
+
+struct ggml_tensor * ggml_rms_norm_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
+    result->op     = GGML_OP_RMS_NORM_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_group_norm
+
+static struct ggml_tensor * ggml_group_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_groups,
+        float                 eps,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_i32(result, 0, n_groups);
+    ggml_set_op_params_f32(result, 1, eps);
+
+    result->op     = GGML_OP_GROUP_NORM;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_group_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_groups,
+        float                 eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, false);
+}
+
+struct ggml_tensor * ggml_group_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_groups,
+        float                 eps) {
+    return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
+}
+
+// ggml_mul_mat
+
+static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0]           == t1->ne[0])  &&
+           (t1->ne[2]%t0->ne[2] == 0)          && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
+}
+
+struct ggml_tensor * ggml_mul_mat(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_mul_mat(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_MUL_MAT;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
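+// Usage sketch (editorial comment): with a = [K, M, ...] and b = [K, N, ...] the
+// result is [M, N, ...] in F32 (a is broadcast over the outer dims of b). A
+// minimal sketch, assuming both tensors live in the same context:
+//
+//     struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_out);
+//     struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_batch);
+//     struct ggml_tensor * y = ggml_mul_mat(ctx, w, x); // y->ne = { n_out, n_batch, 1, 1 }
+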
+void ggml_mul_mat_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 0, prec_i32);
+}
+
+// ggml_mul_mat_id
+
+/*
+    c = ggml_mul_mat_id(ctx, as, b, ids);
+
+    as  -> [cols, rows, n_expert]
+    ids -> [n_experts_used, n_tokens] (i32)
+    b   -> [cols, n_expert_used, n_tokens]
+    c   -> [rows, n_expert_used, n_tokens]
+
+    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+
+    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
+*/
+struct ggml_tensor * ggml_mul_mat_id(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * as,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids) {
+    GGML_ASSERT(!ggml_is_transposed(as));
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
+    GGML_ASSERT(b->ne[3] == 1); // b is 3d
+    GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
+    GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
+    GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
+    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
+
+    const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_MUL_MAT_ID;
+    result->src[0] = as;
+    result->src[1] = b;
+    result->src[2] = ids;
+
+    return result;
+}
+
+// ggml_out_prod
+
+static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[1] == t1->ne[1])   &&
+           (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+           (t1->ne[3]%t0->ne[3] == 0);
+}
+
+struct ggml_tensor * ggml_out_prod(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_can_out_prod(a, b));
+    GGML_ASSERT(!ggml_is_transposed(a));
+
+    // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
+    const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_OUT_PROD;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_scale
+
+static struct ggml_tensor * ggml_scale_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &s, sizeof(s));
+
+    result->op     = GGML_OP_SCALE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_scale(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s) {
+    return ggml_scale_impl(ctx, a, s, false);
+}
+
+struct ggml_tensor * ggml_scale_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s) {
+    return ggml_scale_impl(ctx, a, s, true);
+}
+
+// ggml_set
+
+static struct ggml_tensor * ggml_set_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b));
+
+    // make a view of the destination
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    GGML_ASSERT(offset < (size_t)(1 << 30));
+    int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_SET;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_set(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+}
+
+struct ggml_tensor * ggml_set_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
+}
+
+struct ggml_tensor * ggml_set_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
+}
+
+struct ggml_tensor * ggml_set_1d_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
+}
+
+struct ggml_tensor * ggml_set_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
+}
+
+struct ggml_tensor * ggml_set_2d_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        size_t                nb1,
+        size_t                offset) {
+    return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
+}
+
+// ggml_cpy
+
+static struct ggml_tensor * ggml_cpy_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+    // make a view of the destination
+    struct ggml_tensor * result = ggml_view_tensor(ctx, b);
+    if (strlen(b->name) > 0) {
+        ggml_format_name(result, "%s (copy of %s)", b->name, a->name);
+    } else {
+        ggml_format_name(result, "%s (copy)", a->name);
+    }
+
+    result->op     = GGML_OP_CPY;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cpy(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    return ggml_cpy_impl(ctx, a, b);
+}
+
+struct ggml_tensor * ggml_cast(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum   ggml_type      type) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
+    ggml_format_name(result, "%s (copy)", a->name);
+
+    result->op     = GGML_OP_CPY;
+    result->src[0] = a;
+    result->src[1] = result;
+
+    return result;
+}
+
+// ggml_cont
+
+static struct ggml_tensor * ggml_cont_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    ggml_format_name(result, "%s (cont)", a->name);
+
+    result->op     = GGML_OP_CONT;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cont(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_cont_impl(ctx, a);
+}
+
+// make contiguous, with new shape
+GGML_API struct ggml_tensor * ggml_cont_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0) {
+    return ggml_cont_4d(ctx, a, ne0, 1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, 1, 1);
+}
+
+GGML_API struct ggml_tensor * ggml_cont_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2) {
+    return ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1);
+}
+
+struct ggml_tensor * ggml_cont_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3) {
+    GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3));
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+    ggml_format_name(result, "%s (cont)", a->name);
+
+    result->op     = GGML_OP_CONT;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_reshape
+
+struct ggml_tensor * ggml_reshape(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
+    GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0);
+    ggml_format_name(result, "%s (reshaped)", a->name);
+
+    result->op     = GGML_OP_RESHAPE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0);
+
+    const int64_t ne[1] = { ne0 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
+    ggml_format_name(result, "%s (reshaped)", a->name);
+
+    result->op     = GGML_OP_RESHAPE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1);
+
+    const int64_t ne[2] = { ne0, ne1 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
+    ggml_format_name(result, "%s (reshaped)", a->name);
+
+    result->op     = GGML_OP_RESHAPE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2);
+
+    const int64_t ne[3] = { ne0, ne1, ne2 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
+    ggml_format_name(result, "%s (reshaped)", a->name);
+
+    result->op     = GGML_OP_RESHAPE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_reshape_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+
+    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
+    ggml_format_name(result, "%s (reshaped)", a->name);
+
+    result->op     = GGML_OP_RESHAPE;
+    result->src[0] = a;
+
+    return result;
+}
+
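+// Note (editorial comment): the reshape variants above return views of `a`
+// (view_src = a, offset 0) and never copy data, which is why they assert that
+// the source is contiguous; apply ggml_cont() first when reshaping a permuted
+// or otherwise non-contiguous tensor.
+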
+static struct ggml_tensor * ggml_view_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_dims,
+        const int64_t       * ne,
+        size_t                offset) {
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
+    ggml_format_name(result, "%s (view)", a->name);
+
+    ggml_set_op_params(result, &offset, sizeof(offset));
+
+    result->op     = GGML_OP_VIEW;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_view_1d
+
+struct ggml_tensor * ggml_view_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        size_t                offset) {
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
+
+    return result;
+}
+
+// ggml_view_2d
+
+struct ggml_tensor * ggml_view_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        size_t                nb1,
+        size_t                offset) {
+    const int64_t ne[2] = { ne0, ne1 };
+
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = result->nb[1]*ne1;
+    result->nb[3] = result->nb[2];
+
+    return result;
+}
+
+// ggml_view_3d
+
+struct ggml_tensor * ggml_view_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                offset) {
+    const int64_t ne[3] = { ne0, ne1, ne2 };
+
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = nb2;
+    result->nb[3] = result->nb[2]*ne2;
+
+    return result;
+}
+
+// ggml_view_4d
+
+struct ggml_tensor * ggml_view_4d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        size_t                nb1,
+        size_t                nb2,
+        size_t                nb3,
+        size_t                offset) {
+    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
+
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
+
+    result->nb[1] = nb1;
+    result->nb[2] = nb2;
+    result->nb[3] = nb3;
+
+    return result;
+}
+
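+// Usage sketch (editorial comment): views alias the parent's data. For example,
+// selecting rows [r0, r0 + nr) of a 2D tensor `t` without copying could look like
+//
+//     struct ggml_tensor * rows = ggml_view_2d(ctx, t, t->ne[0], nr,
+//                                               t->nb[1], r0*t->nb[1]);
+//
+// where the row stride nb1 and the offset are both given in bytes.
+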
+// ggml_permute
+
+struct ggml_tensor * ggml_permute(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   axis0,
+        int                   axis1,
+        int                   axis2,
+        int                   axis3) {
+    GGML_ASSERT(axis0 >= 0 && axis0 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis1 >= 0 && axis1 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis2 >= 0 && axis2 < GGML_MAX_DIMS);
+    GGML_ASSERT(axis3 >= 0 && axis3 < GGML_MAX_DIMS);
+
+    GGML_ASSERT(axis0 != axis1);
+    GGML_ASSERT(axis0 != axis2);
+    GGML_ASSERT(axis0 != axis3);
+    GGML_ASSERT(axis1 != axis2);
+    GGML_ASSERT(axis1 != axis3);
+    GGML_ASSERT(axis2 != axis3);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    ggml_format_name(result, "%s (permuted)", a->name);
+
+    int ne[GGML_MAX_DIMS];
+    int nb[GGML_MAX_DIMS];
+
+    ne[axis0] = a->ne[0];
+    ne[axis1] = a->ne[1];
+    ne[axis2] = a->ne[2];
+    ne[axis3] = a->ne[3];
+
+    nb[axis0] = a->nb[0];
+    nb[axis1] = a->nb[1];
+    nb[axis2] = a->nb[2];
+    nb[axis3] = a->nb[3];
+
+    result->ne[0] = ne[0];
+    result->ne[1] = ne[1];
+    result->ne[2] = ne[2];
+    result->ne[3] = ne[3];
+
+    result->nb[0] = nb[0];
+    result->nb[1] = nb[1];
+    result->nb[2] = nb[2];
+    result->nb[3] = nb[3];
+
+    result->op     = GGML_OP_PERMUTE;
+    result->src[0] = a;
+
+    int32_t params[] = { axis0, axis1, axis2, axis3 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    return result;
+}
+
+// ggml_transpose
+
+struct ggml_tensor * ggml_transpose(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    ggml_format_name(result, "%s (transposed)", a->name);
+
+    result->ne[0] = a->ne[1];
+    result->ne[1] = a->ne[0];
+
+    result->nb[0] = a->nb[1];
+    result->nb[1] = a->nb[0];
+
+    result->op     = GGML_OP_TRANSPOSE;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_get_rows
+
+struct ggml_tensor * ggml_get_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(b->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+
+    // TODO: implement non F32 return
+    enum ggml_type type = GGML_TYPE_F32;
+    if (a->type == GGML_TYPE_I32) {
+        type = a->type;
+    }
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
+
+    result->op     = GGML_OP_GET_ROWS;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_get_rows_back
+
+struct ggml_tensor * ggml_get_rows_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
+
+    // TODO: implement non F32 return
+    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]);
+
+    result->op     = GGML_OP_GET_ROWS_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_diag
+
+struct ggml_tensor * ggml_diag(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    GGML_ASSERT(a->ne[1] == 1);
+
+    const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne);
+
+    result->op     = GGML_OP_DIAG;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_diag_mask_inf
+
+static struct ggml_tensor * ggml_diag_mask_inf_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    int32_t params[] = { n_past };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_DIAG_MASK_INF;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_diag_mask_inf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_inf_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
+}
+
+// ggml_diag_mask_zero
+
+static struct ggml_tensor * ggml_diag_mask_zero_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    int32_t params[] = { n_past };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_DIAG_MASK_ZERO;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_diag_mask_zero(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_zero_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_zero_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_zero_impl(ctx, a, n_past, true);
+}
+
+// ggml_soft_max
+
+static struct ggml_tensor * ggml_soft_max_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+
+    if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(ggml_is_matrix(mask));
+        GGML_ASSERT(mask->ne[0] == a->ne[0]);
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
+    }
+
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    float params[] = { scale, max_bias };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_SOFT_MAX;
+    result->src[0] = a;
+    result->src[1] = mask;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
+}
+
+struct ggml_tensor * ggml_soft_max_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
+}
+
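+// Note (editorial comment): ggml_soft_max_ext is the fused variant used for
+// attention scores: the row-wise softmax is taken over scale*a with the
+// (optionally F16) mask added, and a non-zero max_bias enables ALiBi-style
+// positional slopes, which is why a mask is then required.
+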
+// ggml_soft_max_back
+
+static struct ggml_tensor * ggml_soft_max_back_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        bool                  inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_SOFT_MAX_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_soft_max_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, false);
+}
+
+struct ggml_tensor * ggml_soft_max_back_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_soft_max_back_impl(ctx, a, b, true);
+}
+
+// ggml_rope
+
+static struct ggml_tensor * ggml_rope_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow,
+        bool                  inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
+    int sections[4] = {0, 0, 0, 0};
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(params + 11, &sections,     sizeof(int)*4);
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_ROPE;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rope(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_multi(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   sections[4],
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    // Multimodal Rotary Position Embedding
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
+
+    if (c) {
+        GGML_ASSERT(c->type == GGML_TYPE_F32);
+        GGML_ASSERT(c->ne[0] >= n_dims / 2);
+    }
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    memcpy(&params[11], sections, sizeof(int)*4);
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_ROPE;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rope_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
+    );
+}
+
+struct ggml_tensor * ggml_rope_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
+    );
+}
+
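+// Usage sketch (editorial comment): a typical rotary-embedding call, assuming
+// `pos` is an I32 vector with one entry per token (a->ne[2] positions) and the
+// frequency/YaRN parameters come from the model hyperparameters:
+//
+//     struct ggml_tensor * q_rot = ggml_rope_ext(ctx, q, pos, NULL, n_rot, mode,
+//             n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+//             beta_fast, beta_slow);
+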
+struct ggml_tensor * ggml_rope_ext_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
+    );
+}
+
+struct ggml_tensor * ggml_rope_custom(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, false
+    );
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    return ggml_rope_impl(
+        ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
+        ext_factor, attn_factor, beta_fast, beta_slow, true
+    );
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base));
+}
+
+void ggml_rope_yarn_corr_dims(
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
+) {
+    // start and end correction dims
+    float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+    float end   =  ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
+    dims[0] = MAX(0, start);
+    dims[1] = MIN(n_dims - 1, end);
+}
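+
+// Rough illustration with arbitrary values: for n_dims = 128, n_ctx_orig = 4096,
+// freq_base = 10000, beta_fast = 32 and beta_slow = 1, ggml_rope_yarn_corr_dim
+// evaluates to roughly 20.9 and 45.0, so dims comes out around { 20, 46 } - the
+// band of rotary dimensions where the YaRN ramp between extrapolation and
+// interpolation is applied.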
+
+// ggml_rope_back
+
+struct ggml_tensor * ggml_rope_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_I32);
+    GGML_ASSERT(a->ne[2] == b->ne[0]);
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
+    memcpy(params +  5, &freq_base,    sizeof(float));
+    memcpy(params +  6, &freq_scale,   sizeof(float));
+    memcpy(params +  7, &ext_factor,   sizeof(float));
+    memcpy(params +  8, &attn_factor,  sizeof(float));
+    memcpy(params +  9, &beta_fast,    sizeof(float));
+    memcpy(params + 10, &beta_slow,    sizeof(float));
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_ROPE_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
+// ggml_clamp
+
+struct ggml_tensor * ggml_clamp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 min,
+        float                 max) {
+    // TODO: when implementing backward, fix this:
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    float params[] = { min, max };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_CLAMP;
+    result->src[0] = a;
+
+    return result;
+}
+
+static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+}
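+
+// Quick sanity check with arbitrary values: an input of length 5, kernel size 3,
+// stride 1, dilation 1 and no padding gives (5 + 0 - 1*(3 - 1) - 1)/1 + 1 = 3;
+// with padding 1 the result is 5, i.e. a "same"-sized output.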
+
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OH, OW, IC*KH*KW]
+struct ggml_tensor * ggml_im2col(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1,
+        bool                  is_2D,
+        enum ggml_type        dst_type) {
+    if (is_2D) {
+        GGML_ASSERT(a->ne[2] == b->ne[2]);
+    } else {
+        //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+        GGML_ASSERT(b->ne[1] == a->ne[1]);
+        GGML_ASSERT(b->ne[3] == 1);
+    }
+
+    const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
+    const int64_t OW =         ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+
+    GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a");
+    GGML_ASSERT((OW > 0)           && "b too small compared to a");
+
+    const int64_t ne[4] = {
+        is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
+        OW,
+        is_2D ? OH : b->ne[2],
+        is_2D ?      b->ne[3] : 1,
+    };
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_IM2COL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_im2col_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int64_t             * ne,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1,
+        bool                  is_2D) {
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+    int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_IM2COL_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_conv_1d
+
+struct ggml_tensor * ggml_conv_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
+
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
+
+    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
+
+    return result;
+}
+
+// ggml_conv_1d_ph
+
+struct ggml_tensor* ggml_conv_1d_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s,
+        int                   d) {
+    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
+}
+
+// ggml_conv_1d_dw
+
+struct ggml_tensor * ggml_conv_1d_dw(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
+
+    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
+
+    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
+
+    result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
+
+    return result;
+}
+
+// ggml_conv_1d_dw_ph
+
+struct ggml_tensor * ggml_conv_1d_dw_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   d0) {
+    return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
+}
+
+// ggml_conv_transpose_1d
+
+static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
+}
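+
+// Quick sanity check with arbitrary values: an input of length 4, kernel size 3,
+// stride 2, no padding and dilation 1 gives (4 - 1)*2 - 0 + 1*(3 - 1) + 1 = 9
+// output elements.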
+
+GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
+
+    GGML_ASSERT(p0 == 0);
+    GGML_ASSERT(d0 == 1);
+
+    const int64_t ne[4] = {
+        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
+        a->ne[1], b->ne[2], 1,
+    };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { s0, p0, d0 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_conv_2d
+
+// a: [OC,IC, KH, KW]
+// b: [N, IC, IH, IW]
+// result: [N, OC, OH, OW]
+struct ggml_tensor * ggml_conv_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
+
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0],  im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]),  a->ne[3]));                       // [OC,IC, KH, KW] => [OC, IC * KH * KW]
+
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], im2col->ne[3], a->ne[3]); // [OC, N, OH, OW]
+    result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OH, OW]
+
+    return result;
+}
+
+// ggml_conv_2d_sk_p0
+
+struct ggml_tensor * ggml_conv_2d_sk_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_conv_2d(ctx, a, b, a->ne[0], a->ne[1], 0, 0, 1, 1);
+}
+
+// ggml_conv_2d_s1_ph
+
+struct ggml_tensor * ggml_conv_2d_s1_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
+}
+
+// ggml_conv_2d_dw
+
+struct ggml_tensor * ggml_conv_2d_dw(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
+    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
+    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
+                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
+                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+
+    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2],  new_a->ne[3], 1);                       // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+
+    return result;
+}
+
+// ggml_conv_transpose_2d_p0
+
+static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
+    return (ins - 1) * s - 2 * p + ks;
+}
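+
+// Quick sanity check with the same arbitrary values as the 1d case above:
+// input 4, kernel 3, stride 2, no padding gives (4 - 1)*2 + 3 = 9 per spatial dim.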
+
+struct ggml_tensor * ggml_conv_transpose_2d_p0(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   stride) {
+    GGML_ASSERT(a->ne[3] == b->ne[2]);
+
+    const int64_t ne[4] = {
+        ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
+        ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
+        a->ne[2], b->ne[3],
+    };
+
+    struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_i32(result, 0, stride);
+
+    result->op     = GGML_OP_CONV_TRANSPOSE_2D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_pool_*
+
+static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) {
+    return (ins + 2 * p - ks) / s + 1;
+}
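+
+// Quick sanity check with arbitrary values: pooling an input of length 8 with
+// kernel 2, stride 2 and no padding gives (8 + 0 - 2)/2 + 1 = 4 output elements.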
+
+// ggml_pool_1d
+
+struct ggml_tensor * ggml_pool_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_op_pool     op,
+        int                   k0,
+        int                   s0,
+        int                   p0) {
+    const int64_t ne[4] = {
+        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+        a->ne[1],
+        a->ne[2],
+        a->ne[3],
+    };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { op, k0, s0, p0 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_POOL_1D;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_pool_2d
+
+struct ggml_tensor * ggml_pool_2d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_op_pool     op,
+        int                   k0,
+        int                   k1,
+        int                   s0,
+        int                   s1,
+        float                 p0,
+        float                 p1) {
+    struct ggml_tensor * result;
+    const int64_t ne[4] = {
+        ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+        ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
+        a->ne[2],
+        a->ne[3],
+    };
+    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_POOL_2D;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_pool_2d_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * af,
+        enum ggml_op_pool     op,
+        int                   k0,
+        int                   k1,
+        int                   s0,
+        int                   s1,
+        float                 p0,
+        float                 p1) {
+    struct ggml_tensor * result;
+    result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne);
+
+    int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_POOL_2D_BACK;
+    result->src[0] = a;
+    result->src[1] = af;
+
+    return result;
+}
+
+// ggml_upscale
+
+static struct ggml_tensor * ggml_upscale_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2,
+        int                   ne3) {
+    GGML_ASSERT(a->ne[0] <= ne0);
+    GGML_ASSERT(a->ne[1] <= ne1);
+    GGML_ASSERT(a->ne[2] <= ne2);
+    GGML_ASSERT(a->ne[3] <= ne3);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
+
+    result->op     = GGML_OP_UPSCALE;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_upscale(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   scale_factor) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+}
+
+struct ggml_tensor * ggml_upscale_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   ne0,
+        int                   ne1,
+        int                   ne2,
+        int                   ne3) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+}
+
+// ggml_pad
+
+struct ggml_tensor * ggml_pad(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   p3) {
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0,
+            a->ne[1] + p1,
+            a->ne[2] + p2,
+            a->ne[3] + p3);
+
+    result->op     = GGML_OP_PAD;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_pad_reflect_1d
+
+struct ggml_tensor * ggml_pad_reflect_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   p0,
+        int                   p1) {
+    GGML_ASSERT(p0 >= 0);
+    GGML_ASSERT(p1 >= 0);
+
+    GGML_ASSERT(p0 < a->ne[0]); // padding length on each side must be less than the
+    GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded
+
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] + p0 + p1,
+            a->ne[1],
+            a->ne[2],
+            a->ne[3]);
+
+    int32_t params[] = { p0, p1 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_PAD_REFLECT_1D;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_unpad
+
+struct ggml_tensor * ggml_unpad(
+    struct ggml_context * ctx,
+    struct ggml_tensor  * a,
+    int p0, int p1, int p2, int p3) {
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+            a->ne[0] - p0,
+            a->ne[1] - p1,
+            a->ne[2] - p2,
+            a->ne[3] - p3);
+
+    result->op = GGML_OP_UNPAD;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_arange
+
+struct ggml_tensor * ggml_arange(
+        struct ggml_context * ctx,
+        float                 start,
+        float                 stop,
+        float                 step) {
+    GGML_ASSERT(stop > start);
+
+    const int64_t steps = (int64_t) ceilf((stop - start) / step);
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps);
+
+    ggml_set_op_params_f32(result, 0, start);
+    ggml_set_op_params_f32(result, 1, stop);
+    ggml_set_op_params_f32(result, 2, step);
+
+    result->op = GGML_OP_ARANGE;
+
+    return result;
+}
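+
+// Usage sketch with arbitrary values: stop is exclusive, so
+// ggml_arange(ctx, 0.0f, 5.0f, 2.0f) allocates ceilf((5 - 0)/2) = 3 elements,
+// which the compute kernel should fill as { 0, 2, 4 }.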
+
+// ggml_timestep_embedding
+
+struct ggml_tensor * ggml_timestep_embedding(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * timesteps,
+        int                   dim,
+        int                   max_period) {
+    int actual_dim = dim;
+    if (dim % 2 != 0) {
+        actual_dim = dim + 1;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
+
+    ggml_set_op_params_i32(result, 0, dim);
+    ggml_set_op_params_i32(result, 1, max_period);
+
+    result->op     = GGML_OP_TIMESTEP_EMBEDDING;
+    result->src[0] = timesteps;
+
+    return result;
+}
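+
+// Note: odd dims are rounded up (presumably so the sin/cos halves line up), e.g.
+// dim = 15 allocates a 16 x n_timesteps result while the requested dim of 15 is
+// kept in op_params for the kernel.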
+
+// ggml_argsort
+
+struct ggml_tensor * ggml_argsort(
+        struct ggml_context  * ctx,
+        struct ggml_tensor   * a,
+        enum ggml_sort_order   order) {
+    GGML_ASSERT(a->ne[0] <= INT32_MAX);
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) order);
+
+    result->op     = GGML_OP_ARGSORT;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_top_k
+
+struct ggml_tensor * ggml_top_k(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   k) {
+    GGML_ASSERT(a->ne[0] >= k);
+
+    struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_ORDER_DESC);
+
+    result = ggml_view_4d(ctx, result,
+                k, result->ne[1], result->ne[2], result->ne[3],
+                   result->nb[1], result->nb[2], result->nb[3],
+                0);
+
+    return result;
+}
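+
+// Usage sketch with arbitrary shapes: the result is an I32 view of the descending
+// argsort along dim 0, so for logits of shape [n_vocab, n_tokens]
+// ggml_top_k(ctx, logits, 40) yields a [40, n_tokens, 1, 1] view holding the
+// indices of the 40 largest logits for each token.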
+
+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale,
+        float                 max_bias,
+        float                 logit_softcap) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    float params[] = { scale, max_bias, logit_softcap };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_FLASH_ATTN_EXT;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec       prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
+}
+
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * d,
+        bool                  masked) {
+    GGML_ABORT("TODO: adapt to ggml_flash_attn_ext() changes");
+
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,kvne2,ne3]
+    // v shape [M,D,kvne2,ne3]
+
+    const int64_t     D = q->ne[0];
+    const int64_t     N = q->ne[1];
+    const int64_t     M = k->ne[1];
+    const int64_t   ne2 = q->ne[2];
+    const int64_t   ne3 = q->ne[3];
+    const int64_t kvne2 = k->ne[2];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == kvne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == kvne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    GGML_ASSERT(ne2 % kvne2 == 0);
+
+    // store gradients of q, k and v as contiguous tensors concatenated in result.
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    const int64_t elem_q = ggml_nelements(q);
+    const int64_t elem_k = ggml_nelements(k);
+    const int64_t elem_v = ggml_nelements(v);
+
+    enum ggml_type result_type = GGML_TYPE_F32;
+    GGML_ASSERT(ggml_blck_size(result_type) == 1);
+    const size_t tsize = ggml_type_size(result_type);
+
+    const size_t offs_q = 0;
+    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
+    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
+    const size_t end    = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
+
+    const size_t nelements = (end + tsize - 1)/tsize;
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
+
+    int32_t masked_i = masked ? 1 : 0;
+    ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
+    result->op     = GGML_OP_FLASH_ATTN_BACK;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = d;
+
+    return result;
+}
+
+// ggml_ssm_conv
+
+struct ggml_tensor * ggml_ssm_conv(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * sx,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_is_3d(sx));
+    GGML_ASSERT(ggml_is_matrix(c));
+
+    const int64_t d_conv  = c->ne[0];
+    const int64_t d_inner = c->ne[1];
+    const int64_t n_t     = sx->ne[0] - d_conv + 1; // tokens per sequence
+    const int64_t n_s     = sx->ne[2];
+
+    // TODO: maybe support other strides than 1?
+    // FIXME: this is always true?
+    GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
+    GGML_ASSERT(sx->ne[1] == d_inner);
+    GGML_ASSERT(n_t >= 0);
+
+    struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
+
+    result->op     = GGML_OP_SSM_CONV;
+    result->src[0] = sx;
+    result->src[1] = c;
+
+    return result;
+}
+
+// ggml_ssm_scan
+
+struct ggml_tensor * ggml_ssm_scan(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * s,
+        struct ggml_tensor  * x,
+        struct ggml_tensor  * dt,
+        struct ggml_tensor  * A,
+        struct ggml_tensor  * B,
+        struct ggml_tensor  * C) {
+    GGML_ASSERT(ggml_is_contiguous(s));
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(dt));
+    GGML_ASSERT(ggml_is_contiguous(A));
+    GGML_ASSERT(ggml_is_matrix(A));
+    GGML_ASSERT(ggml_is_3d(B));
+    GGML_ASSERT(ggml_is_3d(s));
+    GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
+    GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
+    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(ggml_are_same_shape(B, C));
+
+    {
+        const int64_t d_state      = s->ne[0];
+        const int64_t d_inner      = s->ne[1];
+        const int64_t n_seq_tokens = x->ne[1];
+        const int64_t n_seqs       = x->ne[2];
+
+        GGML_ASSERT(s->ne[2] == n_seqs);
+        GGML_ASSERT(x->ne[0] == d_inner);
+        GGML_ASSERT(A->ne[0] == d_state);
+        GGML_ASSERT(A->ne[1] == d_inner);
+        GGML_ASSERT(B->ne[0] == d_state);
+        GGML_ASSERT(B->ne[1] == n_seq_tokens);
+        GGML_ASSERT(B->ne[2] == n_seqs);
+    }
+
+    // concatenated y + ssm_states
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
+
+    result->op   = GGML_OP_SSM_SCAN;
+    result->src[0] = s;
+    result->src[1] = x;
+    result->src[2] = dt;
+    result->src[3] = A;
+    result->src[4] = B;
+    result->src[5] = C;
+
+    return result;
+}
+
+// ggml_win_part
+
+struct ggml_tensor * ggml_win_part(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w) {
+    GGML_ASSERT(a->ne[3] == 1);
+    GGML_ASSERT(a->type  == GGML_TYPE_F32);
+
+    // padding
+    const int px = (w - a->ne[1]%w)%w;
+    const int py = (w - a->ne[2]%w)%w;
+
+    const int npx = (px + a->ne[1])/w;
+    const int npy = (py + a->ne[2])/w;
+    const int np  = npx*npy;
+
+    const int64_t ne[4] = { a->ne[0], w, w, np, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { npx, npy, w };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_WIN_PART;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_win_unpart
+
+struct ggml_tensor * ggml_win_unpart(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   w0,
+        int                   h0,
+        int                   w) {
+    GGML_ASSERT(a->type == GGML_TYPE_F32);
+
+    const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+    int32_t params[] = { w };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_WIN_UNPART;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_get_rel_pos
+
+struct ggml_tensor * ggml_get_rel_pos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   qh,
+        int                   kh) {
+    GGML_ASSERT(qh == kh);
+    GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
+
+    const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne);
+
+    result->op     = GGML_OP_GET_REL_POS;
+    result->src[0] = a;
+
+    return result;
+}
+
+// ggml_add_rel_pos
+
+static struct ggml_tensor * ggml_add_rel_pos_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_are_same_shape(pw, ph));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    GGML_ASSERT(ggml_is_contiguous(pw));
+    GGML_ASSERT(ggml_is_contiguous(ph));
+    GGML_ASSERT(ph->type == GGML_TYPE_F32);
+    GGML_ASSERT(pw->type == GGML_TYPE_F32);
+    GGML_ASSERT(pw->ne[3] == a->ne[2]);
+    GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
+    GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
+
+    result->op     = GGML_OP_ADD_REL_POS;
+    result->src[0] = a;
+    result->src[1] = pw;
+    result->src[2] = ph;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_add_rel_pos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph) {
+    return ggml_add_rel_pos_impl(ctx, a, pw, ph, false);
+}
+
+struct ggml_tensor * ggml_add_rel_pos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * pw,
+        struct ggml_tensor  * ph) {
+    return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
+}
+
+// ggml_rwkv_wkv6
+
+struct ggml_tensor * ggml_rwkv_wkv6(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * r,
+        struct ggml_tensor  * tf,
+        struct ggml_tensor  * td,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(r));
+    GGML_ASSERT(ggml_is_contiguous(tf));
+    GGML_ASSERT(ggml_is_contiguous(td));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[2];
+    const int64_t n_tokens = k->ne[3];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(k->ne[1] == 1);
+        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
+        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
+        // TODO: RWKV v4 and v5
+        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_RWKV_WKV6;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = r;
+    result->src[3] = tf;
+    result->src[4] = td;
+    result->src[5] = state;
+
+    return result;
+}
+
+// ggml_unary
+
+static struct ggml_tensor * ggml_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op    op,
+        bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous_1(a));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
+
+    result->op     = GGML_OP_UNARY;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op    op) {
+    return ggml_unary_impl(ctx, a, op, false);
+}
+
+struct ggml_tensor * ggml_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_unary_op    op) {
+    return ggml_unary_impl(ctx, a, op, true);
+}
+
+// ggml_map_unary
+
+static struct ggml_tensor * ggml_map_unary_impl_f32(
+        struct ggml_context        * ctx,
+        struct ggml_tensor         * a,
+        const  ggml_unary_op_f32_t   fun,
+        bool                         inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+    result->op     = GGML_OP_MAP_UNARY;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary_f32(
+        struct ggml_context        * ctx,
+        struct ggml_tensor         * a,
+        const  ggml_unary_op_f32_t   fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace_f32(
+        struct ggml_context        * ctx,
+        struct ggml_tensor         * a,
+        const  ggml_unary_op_f32_t   fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_binary
+
+static struct ggml_tensor * ggml_map_binary_impl_f32(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        const  ggml_binary_op_f32_t   fun,
+        bool                          inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+    result->op     = GGML_OP_MAP_BINARY;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary_f32(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        const  ggml_binary_op_f32_t   fun) {
+    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace_f32(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        const  ggml_binary_op_f32_t   fun) {
+    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
+}
+
+// ggml_map_custom1_f32
+
+static struct ggml_tensor * ggml_map_custom1_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun,
+        bool                           inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+    result->op     = GGML_OP_MAP_CUSTOM1_F32;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom1_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun) {
+    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom1_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        const  ggml_custom1_op_f32_t   fun) {
+    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_custom2_f32
+
+static struct ggml_tensor * ggml_map_custom2_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun,
+        bool                           inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+    result->op     = GGML_OP_MAP_CUSTOM2_F32;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom2_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun) {
+    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom2_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        const  ggml_custom2_op_f32_t   fun) {
+    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
+}
+
+// ggml_map_custom3_f32
+
+static struct ggml_tensor * ggml_map_custom3_impl_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun,
+        bool                           inplace) {
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
+
+    result->op     = GGML_OP_MAP_CUSTOM3_F32;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom3_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun) {
+    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
+}
+
+struct ggml_tensor * ggml_map_custom3_inplace_f32(
+        struct ggml_context          * ctx,
+        struct ggml_tensor           * a,
+        struct ggml_tensor           * b,
+        struct ggml_tensor           * c,
+        const  ggml_custom3_op_f32_t   fun) {
+    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
+}
+
+// ggml_map_custom1
+
+static struct ggml_tensor * ggml_map_custom1_impl(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        const  ggml_custom1_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata,
+        bool                       inplace) {
+    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    struct ggml_map_custom1_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+
+    result->op     = GGML_OP_MAP_CUSTOM1;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom1(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        const  ggml_custom1_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
+    return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
+}
+
+struct ggml_tensor * ggml_map_custom1_inplace(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        const  ggml_custom1_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
+    return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
+}
+
+// ggml_map_custom2
+
+static struct ggml_tensor * ggml_map_custom2_impl(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        const  ggml_custom2_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata,
+        bool                       inplace) {
+    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    struct ggml_map_custom2_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+
+    result->op     = GGML_OP_MAP_CUSTOM2;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom2(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        const  ggml_custom2_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
+    return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
+}
+
+struct ggml_tensor * ggml_map_custom2_inplace(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        const  ggml_custom2_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
+    return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
+}
+
+// ggml_map_custom3
+
+static struct ggml_tensor * ggml_map_custom3_impl(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        struct ggml_tensor       * c,
+        const  ggml_custom3_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata,
+        bool                       inplace) {
+    GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0);
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    struct ggml_map_custom3_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+
+    result->op     = GGML_OP_MAP_CUSTOM3;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_custom3(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        struct ggml_tensor       * c,
+        const  ggml_custom3_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
+    return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
+}
+
+struct ggml_tensor * ggml_map_custom3_inplace(
+        struct ggml_context      * ctx,
+        struct ggml_tensor       * a,
+        struct ggml_tensor       * b,
+        struct ggml_tensor       * c,
+        const  ggml_custom3_op_t   fun,
+        int                        n_tasks,
+        void                     * userdata) {
+    return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
+}
+
+// ggml_cross_entropy_loss
+
+struct ggml_tensor * ggml_cross_entropy_loss(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1);
+
+    result->op     = GGML_OP_CROSS_ENTROPY_LOSS;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_cross_entropy_loss_back
+
+struct ggml_tensor * ggml_cross_entropy_loss_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_is_scalar(c));
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = c;
+
+    return result;
+}
+
+// opt_step_adamw
+
+struct ggml_tensor * ggml_opt_step_adamw(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * grad,
+        struct ggml_tensor  * m,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * adamw_params) {
+    GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
+    GGML_ASSERT(ggml_are_same_shape(a, grad));
+    GGML_ASSERT(ggml_are_same_shape(a, m));
+    GGML_ASSERT(ggml_are_same_shape(a, v));
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_OPT_STEP_ADAMW;
+    result->src[0] = a;
+    result->src[1] = grad;
+    result->src[2] = m;
+    result->src[3] = v;
+    result->src[4] = adamw_params;
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct ggml_hash_set ggml_hash_set_new(size_t size) {
+    size = ggml_hash_size(size);
+    struct ggml_hash_set result;
+    result.size = size;
+    result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size);
+    result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t));
+    return result;
+}
+
+void ggml_hash_set_reset(struct ggml_hash_set * hash_set) {
+    memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size));
+}
+
+void ggml_hash_set_free(struct ggml_hash_set * hash_set) {
+    GGML_FREE(hash_set->used);
+    GGML_FREE(hash_set->keys);
+}
+
+size_t ggml_hash_size(size_t min_sz) {
+    // next primes after powers of two
+    static const size_t primes[] = {
+        2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031,
+        2053, 4099, 8209, 16411, 32771, 65537, 131101,
+        262147, 524309, 1048583, 2097169, 4194319, 8388617,
+        16777259, 33554467, 67108879, 134217757, 268435459,
+        536870923, 1073741827, 2147483659
+    };
+    static const size_t n_primes = sizeof(primes)/sizeof(primes[0]);
+
+    // find the smallest prime that is larger than or equal to min_sz
+    size_t l = 0;
+    size_t r = n_primes;
+    while (l < r) {
+        size_t m = (l + r)/2;
+        if (primes[m] < min_sz) {
+            l = m + 1;
+        } else {
+            r = m;
+        }
+    }
+    size_t sz = l < n_primes ? primes[l] : min_sz | 1;
+    return sz;
+}
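+
+// Quick sanity check: ggml_hash_size(100) returns 131, the smallest listed prime
+// >= 100; a request beyond the last table entry falls back to min_sz | 1 (the
+// value forced odd).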
+
+struct hash_map {
+    struct ggml_hash_set set;
+    struct ggml_tensor ** vals;
+};
+
+static struct hash_map * ggml_new_hash_map(size_t size) {
+    struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map));
+    result->set = ggml_hash_set_new(size);
+    result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *));
+    return result;
+}
+
+static void ggml_hash_map_free(struct hash_map * map) {
+    ggml_hash_set_free(&map->set);
+    GGML_FREE(map->vals);
+    GGML_FREE(map);
+}
+
+// utility functions to change gradients
+// isrc is the index of tensor in cgraph->visited_hash_set.keys
+// the corresponding gradients (accumulators) are also stored at position isrc
+// if tensor has a gradient accumulator, modify that accumulator in-place
+// else if there is no gradient for tensor, set the corresponding value
+// else, just add/subtract/etc. the gradients
+
+static void ggml_add_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
+    } else {
+        cgraph->grads[isrc] = tensor;
+    }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
+}
+
+static void ggml_acc_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * tensor,
+        const  size_t         nb1,
+        const  size_t         nb2,
+        const  size_t         nb3,
+        const  size_t         offset) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
+    } else {
+        struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
+        cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
+    }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
+}
+
+static void ggml_add1_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+    } else {
+        cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
+    }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
+}
+
+static void ggml_sub_or_set(
+        struct ggml_context * ctx,
+        struct ggml_cgraph  * cgraph,
+        size_t                isrc,
+        struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
+    if (cgraph->grads[isrc]) {
+        cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+    } else {
+        cgraph->grads[isrc] = ggml_neg(ctx, tensor);
+    }
+    ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
+    ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
+}
+
+static void ggml_compute_backward(
+        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+    struct ggml_tensor * tensor = cgraph->nodes[i];
+    struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);
+
+    if (!grad) {
+        return;
+    }
+
+    struct ggml_tensor * src0 = tensor->src[0];
+    struct ggml_tensor * src1 = tensor->src[1];
+    struct ggml_tensor * src2 = tensor->src[2];
+    struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
+    const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
+    const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
+    const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
+    const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+    const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+    const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
+
+    switch (tensor->op) {
+        case GGML_OP_DUP: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+        } break;
+        case GGML_OP_ADD: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+            if (src1_needs_grads) {
+                struct ggml_tensor * tmp = grad;
+                if (!ggml_are_same_shape(src0, src1)) {
+                    tmp = ggml_repeat_back(ctx, tmp, src1);
+                }
+                ggml_add_or_set(ctx, cgraph, isrc1, tmp);
+            }
+        } break;
+        case GGML_OP_ADD1: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+            if (src1_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
+            }
+        } break;
+        case GGML_OP_ACC: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+            if (src1_needs_grads) {
+                const size_t nb1    = ((int32_t *) tensor->op_params)[0];
+                const size_t nb2    = ((int32_t *) tensor->op_params)[1];
+                const size_t nb3    = ((int32_t *) tensor->op_params)[2];
+                const size_t offset = ((int32_t *) tensor->op_params)[3];
+
+                struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
+                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                    nb1, nb2, nb3, offset);
+
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
+            }
+        } break;
+        case GGML_OP_SUB: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+            if (src1_needs_grads) {
+                ggml_sub_or_set(ctx, cgraph, isrc1, grad);
+            }
+        } break;
+        case GGML_OP_MUL: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+            }
+            if (src1_needs_grads) {
+                struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
+                if (!ggml_are_same_shape(src0, src1)) {
+                    tmp = ggml_repeat_back(ctx, tmp, src1);
+                }
+                ggml_add_or_set(ctx, cgraph, isrc1, tmp);
+            }
+        } break;
+        case GGML_OP_DIV: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
+            }
+            if (src1_needs_grads) {
+                ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
+            }
+        } break;
+        case GGML_OP_SQR: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
+            }
+        } break;
+        case GGML_OP_SQRT: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
+            }
+        } break;
+        case GGML_OP_LOG: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
+            }
+        } break;
+        case GGML_OP_SIN: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
+            }
+        } break;
+        case GGML_OP_COS: {
+            if (src0_needs_grads) {
+                ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
+            }
+        } break;
+        case GGML_OP_SUM: {
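+            // d(sum(x))/dx_i = 1 for every element, so the scalar grad is broadcast over src0 via add1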
+            if (src0_needs_grads) {
+                ggml_add1_or_set(ctx, cgraph, isrc0, grad);
+            }
+        } break;
+        case GGML_OP_SUM_ROWS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
+            }
+        } break;
+        case GGML_OP_MEAN: {
+            if (src0_needs_grads) {
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+            }
+        } break;
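+        // ggml_repeat and ggml_repeat_back are adjoint operations, so each one's backward pass uses the other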
+        case GGML_OP_REPEAT: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
+            }
+        } break;
+        case GGML_OP_REPEAT_BACK: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
+            }
+        } break;
+        case GGML_OP_RMS_NORM: {
+            if (src0_needs_grads) {
+                float eps;
+                memcpy(&eps, tensor->op_params, sizeof(float));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+            }
+        } break;
+        case GGML_OP_MUL_MAT: {
+            // https://cs231n.github.io/optimization-2/#staged
+            // # forward pass
+            // s0 = np.random.randn(5, 10)
+            // s1 = np.random.randn(10, 3)
+            // t = s0.dot(s1)
+
+            // # now suppose we had the gradient on t from above in the circuit
+            // dt = np.random.randn(*t.shape) # same shape as t
+            // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
+            // ds1 = s0.T.dot(dt)
+
+            // tensor.shape [m,p,qq,rr]
+            // src0.shape   [n,m,q1,r1]
+            // src1.shape   [n,p,qq,rr]
+
+            if (src0_needs_grads) {
+                struct ggml_tensor * s1_tg =
+                    ggml_out_prod(ctx, // [n,m,qq,rr]
+                        src1,          // [n,p,qq,rr]
+                        grad);         // [m,p,qq,rr]
+                const int64_t qq = s1_tg->ne[2];
+                const int64_t rr = s1_tg->ne[3];
+                const int64_t q1 = src0->ne[2];
+                const int64_t r1 = src0->ne[3];
+                const bool ne2_broadcasted = qq > q1;
+                const bool ne3_broadcasted = rr > r1;
+                if (ne2_broadcasted || ne3_broadcasted) {
+                    // sum broadcast repetitions of s1_tg into shape of src0
+                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                }
+                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+            }
+            if (src1_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc1,
+                        // ggml_mul_mat(ctx,                   // [n,p,qq,rr]
+                        //     ggml_cont(ctx,                  // [m,n,q1,r1]
+                        //         ggml_transpose(ctx, src0)), // [m,n,q1,r1]
+                        //     grad),                          // [m,p,qq,rr]
+
+                        // when src0 is bigger than tensor->grad (this is mostly the case in llama),
+                        // avoid transpose of src0, rather transpose smaller tensor->grad
+                        // and then use ggml_out_prod
+                        ggml_out_prod(ctx,      // [n,p,qq,rr]
+                            src0,               // [n,m,q1,r1]
+                            ggml_transpose(ctx, // [p,m,qq,rr]
+                                grad)));        // [m,p,qq,rr]
+            }
+        } break;
+        case GGML_OP_SCALE: {
+            if (src0_needs_grads) {
+                float s;
+                memcpy(&s, tensor->op_params, sizeof(float));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
+            }
+        } break;
+        case GGML_OP_SET: {
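+            // tensor is src0 with the region described by nb1/nb2/nb3/offset overwritten by src1:
+            // src1 receives the gradient restricted to that region, src0 receives grad with that region zeroed out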
+            const size_t nb1    = ((const int32_t *) tensor->op_params)[0];
+            const size_t nb2    = ((const int32_t *) tensor->op_params)[1];
+            const size_t nb3    = ((const int32_t *) tensor->op_params)[2];
+            const size_t offset = ((const int32_t *) tensor->op_params)[3];
+
+            struct ggml_tensor * tensor_grad_view = NULL;
+
+            if (src0_needs_grads || src1_needs_grads) {
+                GGML_ASSERT(src0->type == tensor->type);
+                GGML_ASSERT(!cgraph->grads[isrc0] ||                      cgraph->grads[isrc0]->type == grad->type);
+                GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
+
+                tensor_grad_view = ggml_view_4d(ctx,
+                    grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+                    nb1, nb2, nb3, offset);
+            }
+
+            if (src0_needs_grads) {
+                struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
+            }
+
+            if (src1_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
+            }
+        } break;
+        case GGML_OP_CPY: {
+            // cpy overwrites value of src1 by src0 and returns view(src1)
+            // the overwriting is mathematically equivalent to:
+            // tensor = src0 * 1 + src1 * 0
+            if (src0_needs_grads) {
+                // dsrc0 = dtensor * 1
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+            if (src1_needs_grads) {
+                // dsrc1 = dtensor * 0 -> noop
+            }
+        } break;
+        case GGML_OP_CONT: {
+            // same as cpy
+            if (src0_needs_grads) {
+                GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
+                GGML_ASSERT(ggml_is_contiguous(grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+            }
+        } break;
+        case GGML_OP_RESHAPE: {
+            if (src0_needs_grads) {
+                struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
+            }
+        } break;
+        case GGML_OP_VIEW: {
+            if (src0_needs_grads) {
+                size_t offset;
+
+                memcpy(&offset, tensor->op_params, sizeof(offset));
+
+                size_t nb1 = tensor->nb[1];
+                size_t nb2 = tensor->nb[2];
+                size_t nb3 = tensor->nb[3];
+
+                if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
+                    // gradient is typically F32, but src0 could be other type
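+                    // the view parameters are byte offsets/strides in terms of src0's element size;
+                    // rescale them to the gradient's element size (e.g. x2 for an F16 src0 with an F32 gradient)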
+                    size_t ng = ggml_element_size(cgraph->grads[isrc0]);
+                    size_t n0 = ggml_element_size(src0);
+                    GGML_ASSERT(offset % n0 == 0);
+                    GGML_ASSERT(nb1 % n0 == 0);
+                    GGML_ASSERT(nb2 % n0 == 0);
+                    GGML_ASSERT(nb3 % n0 == 0);
+                    offset = (offset / n0) * ng;
+                    nb1 = (nb1 / n0) * ng;
+                    nb2 = (nb2 / n0) * ng;
+                    nb3 = (nb3 / n0) * ng;
+                }
+
+                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
+            }
+        } break;
+        case GGML_OP_PERMUTE: {
+            if (src0_needs_grads) {
+                const int32_t * axes = (const int32_t *) tensor->op_params;
+                const int axis0 = axes[0] & 0x3;
+                const int axis1 = axes[1] & 0x3;
+                const int axis2 = axes[2] & 0x3;
+                const int axis3 = axes[3] & 0x3;
+                int axb[4] = {0,0,0,0}; // axes backward
+                axb[axis0] = 0;
+                axb[axis1] = 1;
+                axb[axis2] = 2;
+                axb[axis3] = 3;
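+                // axb is the inverse permutation (axb[axes[i]] == i), so permuting grad by axb undoes the forward permute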
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
+            }
+        } break;
+        case GGML_OP_TRANSPOSE: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
+            }
+        } break;
+        case GGML_OP_GET_ROWS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
+            }
+            if (src1_needs_grads) {
+                // noop
+            }
+        } break;
+        case GGML_OP_DIAG_MASK_INF: {
+            if (src0_needs_grads) {
+                /* ggml_diag_mask_inf_impl() shouldn't be here */
+                /* ref:  https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
+                const int n_past = ((const int32_t *) tensor->op_params)[0];
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
+            }
+        } break;
+        case GGML_OP_DIAG_MASK_ZERO: {
+            if (src0_needs_grads) {
+                const int n_past = ((const int32_t *) tensor->op_params)[0];
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
+            }
+        } break;
+        case GGML_OP_SOFT_MAX: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+            }
+            GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
+        } break;
+        case GGML_OP_ROPE: {
+            if (src0_needs_grads) {
+                //const int n_past = ((int32_t *) tensor->op_params)[0];
+                const int n_dims     = ((const int32_t *) tensor->op_params)[1];
+                const int mode       = ((const int32_t *) tensor->op_params)[2];
+                //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
+                const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
+                float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+
+                memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
+                memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
+                memcpy(&ext_factor,  (const float *) tensor->op_params +  7, sizeof(float));
+                memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
+                memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
+                memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
+
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
+                        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+            }
+            GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
+        } break;
+        case GGML_OP_IM2COL: {
+            if (src1_needs_grads) {
+                const int32_t s0    = ggml_get_op_params_i32(tensor, 0);
+                const int32_t s1    = ggml_get_op_params_i32(tensor, 1);
+                const int32_t p0    = ggml_get_op_params_i32(tensor, 2);
+                const int32_t p1    = ggml_get_op_params_i32(tensor, 3);
+                const int32_t d0    = ggml_get_op_params_i32(tensor, 4);
+                const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
+                const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
+
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+            }
+        } break;
+        case GGML_OP_POOL_2D: {
+            if (src0_needs_grads) {
+                const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
+                const      int32_t      k0 = ggml_get_op_params_i32(tensor, 1);
+                const      int32_t      k1 = ggml_get_op_params_i32(tensor, 2);
+                const      int32_t      s0 = ggml_get_op_params_i32(tensor, 3);
+                const      int32_t      s1 = ggml_get_op_params_i32(tensor, 4);
+                const      int32_t      p0 = ggml_get_op_params_i32(tensor, 5);
+                const      int32_t      p1 = ggml_get_op_params_i32(tensor, 6);
+
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
+            }
+        } break;
+        case GGML_OP_WIN_PART:
+        case GGML_OP_WIN_UNPART:
+        case GGML_OP_UNARY: {
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_ABS: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
+                    }
+                } break;
+                case GGML_UNARY_OP_SGN: {
+                    // noop
+                } break;
+                case GGML_UNARY_OP_NEG: {
+                    if (src0_needs_grads) {
+                        ggml_sub_or_set(ctx, cgraph, isrc0, grad);
+                    }
+                } break;
+                case GGML_UNARY_OP_STEP: {
+                    // noop
+                } break;
+                case GGML_UNARY_OP_RELU: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
+                    }
+                } break;
+                case GGML_UNARY_OP_SILU: {
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+                    }
+                } break;
+                case GGML_UNARY_OP_EXP: {
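+                    // d(exp(x))/dx = exp(x), which is already available as the forward result (tensor)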
+                    if (src0_needs_grads) {
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
+                    }
+                } break;
+                default: {
+                    fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
+                        __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
+                    GGML_ABORT("fatal error");
+                } //break;
+            }
+        } break;
+        case GGML_OP_CROSS_ENTROPY_LOSS: {
+            if (src0_needs_grads) {
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
+            }
+            GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
+        } break;
+        case GGML_OP_NONE: {
+            // noop
+        } break;
+        case GGML_OP_COUNT:
+        default: {
+            fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
+            GGML_ABORT("fatal error");
+        } //break;
+    }
+
+    GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
+    GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
+    GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
+}
+
+static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+    // check if already visited
+    if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
+        return;
+    }
+
+    for (int i = 0; i < GGML_MAX_SRC; ++i) {
+        const int k =
+            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
+            (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
+            /* unknown order, just fall back to using i*/ i;
+        if (node->src[k]) {
+            ggml_visit_parents(cgraph, node->src[k]);
+        }
+    }
+
+    if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) {
+        // reached a leaf node, not part of the gradient graph (e.g. a constant)
+        GGML_ASSERT(cgraph->n_leafs < cgraph->size);
+
+        if (strlen(node->name) == 0) {
+            ggml_format_name(node, "leaf_%d", cgraph->n_leafs);
+        }
+
+        cgraph->leafs[cgraph->n_leafs] = node;
+        cgraph->n_leafs++;
+    } else {
+        GGML_ASSERT(cgraph->n_nodes < cgraph->size);
+
+        if (strlen(node->name) == 0) {
+            ggml_format_name(node, "node_%d", cgraph->n_nodes);
+        }
+
+        cgraph->nodes[cgraph->n_nodes] = node;
+        cgraph->n_nodes++;
+    }
+}
+
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+    if (!expand) {
+        // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
+        ggml_graph_clear(cgraph);
+    }
+
+    const int n0 = cgraph->n_nodes;
+
+    ggml_visit_parents(cgraph, tensor);
+
+    const int n_new = cgraph->n_nodes - n0;
+    GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
+
+    if (n_new > 0) {
+        // the last added node should always be the starting point
+        GGML_ASSERT(cgraph->nodes[cgraph->n_nodes - 1] == tensor);
+    }
+}
+
+void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+    ggml_build_forward_impl(cgraph, tensor, true);
+}
+
+void ggml_build_backward_expand(
+        struct ggml_context * ctx_static,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * cgraph,
+        bool                  accumulate) {
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    GGML_ASSERT(cgraph->grads);
+    GGML_ASSERT(cgraph->grad_accs);
+
+    const int n_nodes_f = cgraph->n_nodes;
+
+    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
+
+    {
+        bool any_params = false;
+        bool any_loss   = false;
+        for (int i = 0; i < n_nodes_f; ++i) {
+            struct ggml_tensor * node = cgraph->nodes[i];
+            any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
+            any_loss   = any_loss   || (node->flags & GGML_TENSOR_FLAG_LOSS);
+        }
+        GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
+        GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
+    }
+
+    for (int i = 0; i < n_nodes_f; ++i) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        if (node->type == GGML_TYPE_I32) {
+            continue;
+        }
+
+        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
+        bool ignore_src[GGML_MAX_SRC] = {false};
+        switch (node->op) {
+            // gradients in node->src[0] for one reason or another have no effect on output gradients
+            case GGML_OP_IM2COL:      // only used for its shape
+            case GGML_OP_IM2COL_BACK: // same as IM2COL
+                ignore_src[0] = true;
+                break;
+            case GGML_OP_UNARY: {
+                const enum ggml_unary_op uop = ggml_get_unary_op(node);
+                // SGN and STEP unary ops are piecewise constant
+                if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) {
+                    ignore_src[0] = true;
+                }
+            } break;
+
+            // gradients in node->src[1] for one reason or another have no effect on output gradients
+            case GGML_OP_CPY:           // gradients in CPY target are irrelevant
+            case GGML_OP_GET_ROWS:      // row indices not differentiable
+            case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
+            case GGML_OP_ROPE:          // positions not differentiable
+                ignore_src[1] = true;
+                break;
+
+            default:
+                break;
+        }
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
+                continue;
+            }
+            GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
+            node_needs_grad = true;
+            break;
+        }
+        if (!node_needs_grad) {
+            continue;
+        }
+
+        // inplace operations are currently not supported
+        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
+            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
+
+        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
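+        // persistent gradient accumulators are only needed for the loss and, when accumulating, for the parameters;
+        // all other gradients are created on the fly in ctx_compute by ggml_compute_backward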
+        if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
+            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
+            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
+            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+        }
+        grads_needed[igrad] = true;
+    }
+
+    for (int i = n_nodes_f - 1; i >= 0; --i) {
+        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
+        // use allocator to automatically make inplace operations
+        ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+    }
+
+    free(grads_needed);
+}
+
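+// round *p up to the requested alignment, return it, and advance *p past the following size bytes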
+static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
+    void * ptr = *p;
+    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
+    *p = (void *) ((char *) ptr + size);
+    return ptr;
+}
+
+static size_t ggml_graph_nbytes(size_t size, bool grads) {
+    size_t hash_size = ggml_hash_size(size * 2);
+    void * p = 0;
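+    // lay out all members starting from a NULL base pointer; the final pointer value equals the required size in bytes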
+    incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
+    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
+    incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+    incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
+    if (grads) {
+        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
+        incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
+    }
+    incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
+
+    size_t nbytes = (size_t) p;
+    return nbytes;
+}
+
+size_t ggml_graph_overhead_custom(size_t size, bool grads) {
+    return GGML_OBJECT_SIZE + GGML_PAD(ggml_graph_nbytes(size, grads), GGML_MEM_ALIGN);
+}
+
+size_t ggml_graph_overhead(void) {
+    return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) {
+    const size_t obj_size = ggml_graph_nbytes(size, grads);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size);
+    struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs);
+
+    // the size of the hash table is doubled since it needs to hold both nodes and leafs
+    size_t hash_size = ggml_hash_size(size * 2);
+
+    void * p = cgraph + 1;
+
+    struct ggml_tensor ** nodes_ptr     =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** leafs_ptr     =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** hash_keys_ptr =         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** grads_ptr     = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+
+    ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
+
+    // check that we allocated the correct amount of memory
+    assert(obj_size == (size_t)((char *)p - (char *)cgraph));
+
+    *cgraph = (struct ggml_cgraph) {
+        /*.size         =*/ size,
+        /*.n_nodes      =*/ 0,
+        /*.n_leafs      =*/ 0,
+        /*.nodes        =*/ nodes_ptr,
+        /*.grads        =*/ grads_ptr,
+        /*.grad_accs    =*/ grad_accs_ptr,
+        /*.leafs        =*/ leafs_ptr,
+        /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },
+        /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
+    };
+
+    ggml_hash_set_reset(&cgraph->visited_hash_set);
+    if (grads) {
+        memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
+        memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
+    }
+
+    return cgraph;
+}
+
+struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
+    return ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, false);
+}
+
+struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
+    struct ggml_cgraph cgraph = {
+        /*.size             =*/ 0,
+        /*.n_nodes          =*/ i1 - i0,
+        /*.n_leafs          =*/ 0,
+        /*.nodes            =*/ cgraph0->nodes + i0,
+        /*.grads            =*/ NULL, // gradients would need visited_hash_set
+        /*.grad_accs        =*/ NULL,
+        /*.leafs            =*/ NULL,
+        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.order            =*/ cgraph0->order,
+    };
+
+    return cgraph;
+}
+
+void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
+    GGML_ASSERT(dst->size >= src->n_leafs);
+    GGML_ASSERT(dst->size >= src->n_nodes);
+    GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size);
+
+    dst->n_leafs = src->n_leafs;
+    dst->n_nodes = src->n_nodes;
+    dst->order   = src->order;
+
+    for (int i = 0; i < src->n_leafs; ++i) {
+        dst->leafs[i] = src->leafs[i];
+    }
+
+    for (int i = 0; i < src->n_nodes; ++i) {
+        dst->nodes[i] = src->nodes[i];
+    }
+
+    for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
+        // copy all hashset keys (tensors) that are in use
+        if (ggml_bitset_get(src->visited_hash_set.used, i)) {
+            ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+        }
+    }
+
+    if (dst->grads) {
+        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    }
+    if (src->grads) {
+        GGML_ASSERT(dst->grads     != NULL);
+        GGML_ASSERT(dst->grad_accs != NULL);
+        for (int i = 0; i < src->n_nodes; ++i) {
+            const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
+            const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+
+            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
+
+            dst->grads[igrad_dst]     = src->grads[igrad_src];
+            dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
+        }
+    }
+}
+
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+    ggml_graph_cpy(cgraph, result);
+    return result;
+}
+
+struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    if (ggml_is_empty(tensor)) {
+        return tensor;
+    }
+    if (tensor->buffer) {
+        ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
+    } else {
+        GGML_ASSERT(tensor->data);
+        memset(tensor->data, 0, ggml_nbytes(tensor));
+    }
+    return tensor;
+}
+
+void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    GGML_ASSERT(cgraph->grads != NULL);
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node     = cgraph->nodes[i];
+        struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);
+
+        if (node->op == GGML_OP_OPT_STEP_ADAMW) {
+            // clear momenta
+            ggml_set_zero(node->src[2]);
+            ggml_set_zero(node->src[3]);
+        }
+
+        // initial gradients of loss should be 1, 0 otherwise
+        if (grad_acc) {
+            if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+                GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_scalar(grad_acc));
+
+                const float onef = 1.0f;
+                if (grad_acc->buffer) {
+                    ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
+                } else {
+                    GGML_ASSERT(grad_acc->data);
+                    *((float *) grad_acc->data) = onef;
+                }
+            } else {
+                ggml_set_zero(grad_acc);
+            }
+        }
+    }
+}
+
+void ggml_graph_clear(struct ggml_cgraph * cgraph) {
+    cgraph->n_leafs = 0;
+    cgraph->n_nodes = 0;
+    ggml_hash_set_reset(&cgraph->visited_hash_set);
+}
+
+int ggml_graph_size(struct ggml_cgraph * cgraph) {
+    return cgraph->size;
+}
+
+struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) {
+    if (i < 0) {
+        GGML_ASSERT(cgraph->n_nodes + i >= 0);
+        return cgraph->nodes[cgraph->n_nodes + i];
+    }
+
+    GGML_ASSERT(i < cgraph->n_nodes);
+    return cgraph->nodes[i];
+}
+
+struct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) {
+    return cgraph->nodes;
+}
+
+int ggml_graph_n_nodes(struct ggml_cgraph * cgraph) {
+    return cgraph->n_nodes;
+}
+
+void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+    GGML_ASSERT(cgraph->size > cgraph->n_nodes);
+    cgraph->nodes[cgraph->n_nodes] = tensor;
+    cgraph->n_nodes++;
+}
+
+struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        struct ggml_tensor * leaf = cgraph->leafs[i];
+
+        if (strcmp(leaf->name, name) == 0) {
+            return leaf;
+        }
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        if (strcmp(node->name, name) == 0) {
+            return node;
+        }
+    }
+
+    return NULL;
+}
+
+struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL;
+}
+
+struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+    return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL;
+}
+
+void ggml_graph_print(const struct ggml_cgraph * cgraph) {
+    GGML_LOG_INFO("=== GRAPH ===\n");
+
+    GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes);
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
+                i,
+                node->ne[0], node->ne[1], node->ne[2],
+                ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
+                      ggml_graph_get_grad(cgraph, node) ? "g" : " ");
+    }
+
+    GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
+    for (int i = 0; i < cgraph->n_leafs; i++) {
+        struct ggml_tensor * node = cgraph->leafs[i];
+
+        GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
+                i,
+                node->ne[0], node->ne[1],
+                ggml_op_name(node->op),
+                ggml_get_name(node));
+    }
+
+    GGML_LOG_INFO("========================================\n");
+}
+
+// check if node is part of the graph
+static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    if (cgraph == NULL) {
+        return true;
+    }
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        if (cgraph->nodes[i] == node) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * parent = cgraph->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);
+
+        if (grad == node) {
+            return parent;
+        }
+    }
+
+    return NULL;
+}
+
+static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {
+    struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
+    struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);
+    fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
+            gparent0 ? (void *) gparent0 : (void *) parent,
+            gparent0 ? "g" : "x",
+            gparent ? (void *) gparent : (void *) node,
+            gparent ? "g" : "x",
+            gparent ? "empty" : "vee",
+            gparent ? "dashed" : "solid",
+            label);
+}
+
+static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label)  {
+    fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
+            (void *) parent, "x",
+            (void *) node, "x",
+            label);
+}
+
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+    char color[16];
+
+    FILE * fp = ggml_fopen(filename, "w");
+    GGML_ASSERT(fp);
+
+    fprintf(fp, "digraph G {\n");
+    fprintf(fp, "  newrank = true;\n");
+    fprintf(fp, "  rankdir = TB;\n");
+
+    for (int i = 0; i < gb->n_nodes; i++) {
+        struct ggml_tensor * node = gb->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);
+
+        if (ggml_graph_get_parent(gb, node) != NULL) {
+            continue;
+        }
+
+        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+            snprintf(color, sizeof(color), "yellow");
+        } else if (grad) {
+            if (ggml_graph_find(gf, node)) {
+                snprintf(color, sizeof(color), "green");
+            } else {
+                snprintf(color, sizeof(color), "lightblue");
+            }
+        } else {
+            snprintf(color, sizeof(color), "white");
+        }
+
+        fprintf(fp, "  \"%p\" [ "
+                    "style = filled; fillcolor = %s; shape = record; "
+                    "label=\"",
+                (void *) node, color);
+
+        if (strlen(node->name) > 0) {
+            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
+        } else {
+            fprintf(fp, "(%s)|", ggml_type_name(node->type));
+        }
+
+        if (ggml_is_matrix(node)) {
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], ggml_op_symbol(node->op));
+        } else {
+            fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
+        }
+
+        if (grad) {
+            fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(grad->op));
+        } else {
+            fprintf(fp, "\"; ]\n");
+        }
+    }
+
+    for (int i = 0; i < gb->n_leafs; i++) {
+        struct ggml_tensor * node = gb->leafs[i];
+
+        snprintf(color, sizeof(color), "pink");
+
+        fprintf(fp, "  \"%p\" [ "
+                    "style = filled; fillcolor = %s; shape = record; "
+                    "label=\"",
+                (void *) node, color);
+
+        if (strlen(node->name) > 0) {
+            fprintf(fp, "%s (%s)|", node->name, ggml_type_name(node->type));
+        } else {
+            fprintf(fp, "(%s)|", ggml_type_name(node->type));
+        }
+
+        fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
+        if (ggml_nelements(node) < 5 && node->data != NULL) {
+            fprintf(fp, " | (");
+            for (int j = 0; j < ggml_nelements(node); j++) {
+                // FIXME: use ggml-backend to obtain the tensor data
+                //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
+                //    fprintf(fp, "%d", ggml_get_i32_1d(node, j));
+                //}
+                //else if (node->type == GGML_TYPE_F32 ||
+                //         node->type == GGML_TYPE_F16 ||
+                //         node->type == GGML_TYPE_BF16) {
+                //    fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j));
+                //}
+                //else
+                {
+                    fprintf(fp, "#");
+                }
+                if (j < ggml_nelements(node) - 1) {
+                    fprintf(fp, ", ");
+                }
+            }
+            fprintf(fp, ")");
+        }
+        fprintf(fp, "\"; ]\n");
+    }
+
+    for (int i = 0; i < gb->n_nodes; i++) {
+        struct ggml_tensor * node = gb->nodes[i];
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            if (node->src[j]) {
+                char label[16];
+                snprintf(label, sizeof(label), "src %d", j);
+                ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
+            }
+        }
+    }
+
+    for (int i = 0; i < gb->n_leafs; i++) {
+        struct ggml_tensor * node = gb->leafs[i];
+
+        for (int j = 0; j < GGML_MAX_SRC; j++) {
+            if (node->src[j]) {
+                char label[16];
+                snprintf(label, sizeof(label), "src %d", j);
+                ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
+            }
+        }
+    }
+
+    fprintf(fp, "}\n");
+
+    fclose(fp);
+
+    GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_set_input(struct ggml_tensor * tensor) {
+    tensor->flags |= GGML_TENSOR_FLAG_INPUT;
+}
+
+void ggml_set_output(struct ggml_tensor * tensor) {
+    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
+}
+
+void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    GGML_UNUSED(ctx); // TODO: remove this parameter
+    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
+}
+
+void ggml_set_loss(struct ggml_tensor * tensor) {
+    GGML_ASSERT(ggml_is_scalar(tensor));
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    tensor->flags |= GGML_TENSOR_FLAG_LOSS;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+void ggml_quantize_init(enum ggml_type type) {
+    ggml_critical_section_start();
+
+    switch (type) {
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:   iq2xs_init_impl(type); break;
+        case GGML_TYPE_IQ3_XXS: iq3xs_init_impl(256); break;
+        case GGML_TYPE_IQ3_S:   iq3xs_init_impl(512); break;
+        default: // nothing
+            break;
+    }
+
+    ggml_critical_section_end();
+}
+
+void ggml_quantize_free(void) {
+    ggml_critical_section_start();
+
+    iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
+    iq2xs_free_impl(GGML_TYPE_IQ2_XS);
+    iq2xs_free_impl(GGML_TYPE_IQ1_S);
+    iq3xs_free_impl(256);
+
+    ggml_critical_section_end();
+}
+
+bool ggml_quantize_requires_imatrix(enum ggml_type type) {
+    return
+        type == GGML_TYPE_IQ2_XXS ||
+        type == GGML_TYPE_IQ2_XS  ||
+        type == GGML_TYPE_IQ1_S;//   ||
+        //type == GGML_TYPE_IQ1_M;
+}
+
+size_t ggml_quantize_chunk(
+        enum ggml_type   type,
+           const float * src,
+                  void * dst,
+               int64_t   start,
+               int64_t   nrows,
+               int64_t   n_per_row,
+           const float * imatrix) {
+    const int64_t n = (int64_t) nrows * n_per_row;
+
+    if (ggml_quantize_requires_imatrix(type)) {
+        GGML_ASSERT(imatrix != NULL);
+    }
+
+    GGML_ASSERT(start % type_traits[type].blck_size == 0);
+    GGML_ASSERT(start % n_per_row == 0);
+
+    ggml_quantize_init(type); // this is noop if already initialized
+
+    const size_t start_row = start / n_per_row;
+    const size_t row_size  = ggml_row_size(type, n_per_row);
+
+    size_t result = 0;
+
+    switch (type) {
+        case GGML_TYPE_Q4_0:    result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q4_1:    result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q5_K:    result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_Q6_K:    result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_TQ1_0:   result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_TQ2_0:   result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ2_XS:  result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ3_S:   result = quantize_iq3_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ2_S:   result = quantize_iq2_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ1_S:   result = quantize_iq1_s  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_F16:
+            {
+                size_t elemsize = sizeof(ggml_fp16_t);
+                ggml_fp32_to_fp16_row(src + start, (ggml_fp16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                size_t elemsize = sizeof(ggml_bf16_t);
+                ggml_fp32_to_bf16_row_ref(src + start, (ggml_bf16_t *)dst + start, n);
+                result = n * elemsize;
+            } break;
+        case GGML_TYPE_F32:
+            {
+                size_t elemsize = sizeof(float);
+                result = n * elemsize;
+                memcpy((uint8_t *)dst + start * elemsize, src + start, result);
+            } break;
+        default:
+            assert(false);
+    }
+
+    GGML_ASSERT(result == nrows * row_size);
+
+    return result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct gguf_str {
+    uint64_t n;  // GGUFv2
+    char * data;
+};
+
+static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
+    [GGUF_TYPE_UINT8]   = sizeof(uint8_t),
+    [GGUF_TYPE_INT8]    = sizeof(int8_t),
+    [GGUF_TYPE_UINT16]  = sizeof(uint16_t),
+    [GGUF_TYPE_INT16]   = sizeof(int16_t),
+    [GGUF_TYPE_UINT32]  = sizeof(uint32_t),
+    [GGUF_TYPE_INT32]   = sizeof(int32_t),
+    [GGUF_TYPE_FLOAT32] = sizeof(float),
+    [GGUF_TYPE_BOOL]    = sizeof(bool),
+    [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),
+    [GGUF_TYPE_UINT64]  = sizeof(uint64_t),
+    [GGUF_TYPE_INT64]   = sizeof(int64_t),
+    [GGUF_TYPE_FLOAT64] = sizeof(double),
+    [GGUF_TYPE_ARRAY]   = 0, // undefined
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
+    [GGUF_TYPE_UINT8]   = "u8",
+    [GGUF_TYPE_INT8]    = "i8",
+    [GGUF_TYPE_UINT16]  = "u16",
+    [GGUF_TYPE_INT16]   = "i16",
+    [GGUF_TYPE_UINT32]  = "u32",
+    [GGUF_TYPE_INT32]   = "i32",
+    [GGUF_TYPE_FLOAT32] = "f32",
+    [GGUF_TYPE_BOOL]    = "bool",
+    [GGUF_TYPE_STRING]  = "str",
+    [GGUF_TYPE_ARRAY]   = "arr",
+    [GGUF_TYPE_UINT64]  = "u64",
+    [GGUF_TYPE_INT64]   = "i64",
+    [GGUF_TYPE_FLOAT64] = "f64",
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+union gguf_value {
+    uint8_t  uint8;
+    int8_t   int8;
+    uint16_t uint16;
+    int16_t  int16;
+    uint32_t uint32;
+    int32_t  int32;
+    float    float32;
+    uint64_t uint64;
+    int64_t  int64;
+    double   float64;
+    bool     bool_;
+
+    struct gguf_str str;
+
+    struct {
+        enum gguf_type type;
+
+        uint64_t n;  // GGUFv2
+        void * data;
+    } arr;
+};
+
+struct gguf_kv {
+    struct gguf_str key;
+
+    enum  gguf_type  type;
+    union gguf_value value;
+};
+
+struct gguf_header {
+    char magic[4];
+
+    uint32_t version;
+    uint64_t n_tensors; // GGUFv2
+    uint64_t n_kv;      // GGUFv2
+};
+
+struct gguf_tensor_info {
+    struct gguf_str name;
+
+    uint32_t n_dims;
+    uint64_t ne[GGML_MAX_DIMS];
+
+    enum ggml_type type;
+
+    uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
+
+    // for writing API
+    const void * data;
+    size_t size;
+};
+
+struct gguf_context {
+    struct gguf_header header;
+
+    struct gguf_kv          * kv;
+    struct gguf_tensor_info * infos;
+
+    size_t alignment;
+    size_t offset;    // offset of `data` from beginning of file
+    size_t size;      // size of `data` in bytes
+
+    //uint8_t * padding;
+    void * data;
+};
+
+size_t gguf_type_size(enum gguf_type type) {
+    GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
+    return GGUF_TYPE_SIZE[type];
+}
+
+static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
+    if (info->n_dims > GGML_MAX_DIMS) {
+        fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
+        return false;
+    }
+
+    if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
+        fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
+        return false;
+    }
+
+    if (strlen(info->name.data) >= GGML_MAX_NAME) {
+        fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
+        return false;
+    }
+
+    for (uint32_t i = 0; i < info->n_dims; ++i) {
+        if (info->ne[i] <= 0) {
+            fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
+            return false;
+        }
+    }
+
+    // prevent overflow for total number of elements
+    if (INT64_MAX/info->ne[1] <= info->ne[0]) {
+        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
+        return false;
+    }
+
+    if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
+        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
+        return false;
+    }
+
+    if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
+        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
+        return false;
+    }
+
+    return true;
+}
+
+static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
+    const size_t n = fread(dst, 1, size, file);
+    *offset += n;
+    return n == size;
+}
+
+static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
+    p->n    = 0;
+    p->data = NULL;
+
+    bool ok = true;
+
+    ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
+
+    // early exit if the string length is invalid; this prevents p->n + 1 from overflowing in the calloc below
+    if (p->n == SIZE_MAX) {
+        fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
+        return false;
+    }
+
+    p->data = calloc(p->n + 1, 1);
+    if (!p->data) {
+        fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
+        return false;
+    }
+
+    ok = ok && gguf_fread_el(file,  p->data, p->n, offset);
+
+    return ok;
+}
+
+static void gguf_free_kv(struct gguf_kv * kv) {
+    if (kv->key.data) {
+        GGML_FREE(kv->key.data);
+    }
+
+    if (kv->type == GGUF_TYPE_STRING) {
+        if (kv->value.str.data) {
+            GGML_FREE(kv->value.str.data);
+        }
+    }
+
+    if (kv->type == GGUF_TYPE_ARRAY) {
+        if (kv->value.arr.data) {
+            if (kv->value.arr.type == GGUF_TYPE_STRING) {
+                for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
+                    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
+                    if (str->data) {
+                        GGML_FREE(str->data);
+                    }
+                }
+            }
+            GGML_FREE(kv->value.arr.data);
+        }
+    }
+}
+
+struct gguf_context * gguf_init_empty(void) {
+    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
+    if (!ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
+        return NULL;
+    }
+
+    memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
+    ctx->header.version   = GGUF_VERSION;
+    ctx->header.n_tensors = 0;
+    ctx->header.n_kv      = 0;
+
+    ctx->kv    = NULL;
+    ctx->infos = NULL;
+
+    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
+    ctx->offset    = 0;
+    ctx->size      = 0;
+
+    ctx->data = NULL;
+
+    return ctx;
+}
+
+struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
+    // offset from start of file
+    size_t offset = 0;
+
+    char magic[4];
+
+    // check the magic before making allocations
+    {
+        gguf_fread_el(file, &magic, sizeof(magic), &offset);
+
+        for (uint32_t i = 0; i < sizeof(magic); i++) {
+            if (magic[i] != GGUF_MAGIC[i]) {
+                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
+                return NULL;
+            }
+        }
+    }
+
+    bool ok = true;
+
+    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
+    if (!ctx) {
+        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
+        return NULL;
+    }
+
+    // read the header
+    {
+        strncpy(ctx->header.magic, magic, 4);
+
+        ctx->kv    = NULL;
+        ctx->infos = NULL;
+        ctx->data  = NULL;
+
+        ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
+        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
+        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
+
+        if (ctx->header.version == 1) {
+            fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
+            gguf_free(ctx);
+            return NULL;
+        }
+
+        // sanity checks to prevent integer/buffer overflows
+
+        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
+        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
+        ok = ok && (ctx->header.n_kv      < (SIZE_MAX/2)/sizeof(struct gguf_kv));
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read header\n", __func__);
+            gguf_free(ctx);
+            return NULL;
+        }
+    }
+
+    // read the kv pairs
+    {
+        const uint64_t n_kv = ctx->header.n_kv;
+
+        if (n_kv > 0) {
+            ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
+            if (!ctx->kv) {
+                fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
+                gguf_free(ctx);
+                return NULL;
+            }
+        }
+
+        for (uint64_t i = 0; i < n_kv; ++i) {
+            struct gguf_kv * kv = &ctx->kv[i];
+
+            //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
+
+            ok = ok && gguf_fread_str(file, &kv->key,                    &offset);
+            ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
+
+            //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
+
+            switch (kv->type) {
+                case GGUF_TYPE_UINT8:   ok = ok && gguf_fread_el (file, &kv->value.uint8,   sizeof(kv->value.uint8),   &offset); break;
+                case GGUF_TYPE_INT8:    ok = ok && gguf_fread_el (file, &kv->value.int8,    sizeof(kv->value.int8),    &offset); break;
+                case GGUF_TYPE_UINT16:  ok = ok && gguf_fread_el (file, &kv->value.uint16,  sizeof(kv->value.uint16),  &offset); break;
+                case GGUF_TYPE_INT16:   ok = ok && gguf_fread_el (file, &kv->value.int16,   sizeof(kv->value.int16),   &offset); break;
+                case GGUF_TYPE_UINT32:  ok = ok && gguf_fread_el (file, &kv->value.uint32,  sizeof(kv->value.uint32),  &offset); break;
+                case GGUF_TYPE_INT32:   ok = ok && gguf_fread_el (file, &kv->value.int32,   sizeof(kv->value.int32),   &offset); break;
+                case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
+                case GGUF_TYPE_UINT64:  ok = ok && gguf_fread_el (file, &kv->value.uint64,  sizeof(kv->value.uint64),  &offset); break;
+                case GGUF_TYPE_INT64:   ok = ok && gguf_fread_el (file, &kv->value.int64,   sizeof(kv->value.int64),   &offset); break;
+                case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
+                case GGUF_TYPE_BOOL:    ok = ok && gguf_fread_el (file, &kv->value.bool_,   sizeof(kv->value.bool_),   &offset); break;
+                case GGUF_TYPE_STRING:  ok = ok && gguf_fread_str(file, &kv->value.str,                                &offset); break;
+                case GGUF_TYPE_ARRAY:
+                    {
+                        ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
+                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n),    &offset);
+
+                        switch (kv->value.arr.type) {
+                            case GGUF_TYPE_UINT8:
+                            case GGUF_TYPE_INT8:
+                            case GGUF_TYPE_UINT16:
+                            case GGUF_TYPE_INT16:
+                            case GGUF_TYPE_UINT32:
+                            case GGUF_TYPE_INT32:
+                            case GGUF_TYPE_FLOAT32:
+                            case GGUF_TYPE_UINT64:
+                            case GGUF_TYPE_INT64:
+                            case GGUF_TYPE_FLOAT64:
+                            case GGUF_TYPE_BOOL:
+                                {
+                                    // prevent from integer overflow in the malloc below
+                                    if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
+                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
+                                        gguf_free(ctx);
+                                        return NULL;
+                                    }
+
+                                    kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
+                                    if (!kv->value.arr.data) {
+                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
+                                        gguf_free(ctx);
+                                        return NULL;
+                                    }
+
+                                    ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
+                                } break;
+                            case GGUF_TYPE_STRING:
+                                {
+                                    // prevent from integer overflow in the malloc below
+                                    if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
+                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
+                                        gguf_free(ctx);
+                                        return NULL;
+                                    }
+
+                                    kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
+                                    if (!kv->value.arr.data) {
+                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
+                                        gguf_free(ctx);
+                                        return NULL;
+                                    }
+
+                                    for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
+                                        ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
+                                    }
+                                } break;
+                            case GGUF_TYPE_ARRAY:
+                            default:
+                                {
+                                    fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
+                                    ok = false;
+                                } break;
+                        }
+                    } break;
+                default:
+                    {
+                        fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
+                        ok = false;
+                    } break;
+            }
+
+            if (!ok) {
+                break;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
+            gguf_free(ctx);
+            return NULL;
+        }
+    }
+
+    // read the tensor infos
+    if (ctx->header.n_tensors > 0) {
+        ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
+        if (!ctx->infos) {
+            fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
+            gguf_free(ctx);
+            return NULL;
+        }
+
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+            struct gguf_tensor_info * info = &ctx->infos[i];
+
+            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
+                info->ne[j] = 1;
+            }
+
+            ok = ok && gguf_fread_str(file, &info->name,                          &offset);
+            ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
+
+            ok = ok && (info->n_dims <= GGML_MAX_DIMS);
+
+            for (uint32_t j = 0; j < info->n_dims; ++j) {
+                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
+            }
+
+            ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
+            ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
+
+            ok = ok && gguf_tensor_info_sanitize(info);
+
+            // make sure there are no duplicate tensor names
+            for (uint64_t j = 0; j < i && ok; ++j) {
+                if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
+                    fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
+                    ok = false;
+                }
+            }
+
+            if (!ok) {
+                fprintf(stderr, "%s: failed to read tensor info\n", __func__);
+                gguf_free(ctx);
+                return NULL;
+            }
+        }
+    }
+
+    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
+
+    int alignment_idx = gguf_find_key(ctx, "general.alignment");
+    if (alignment_idx != -1) {
+        ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
+    }
+
+    // we require the data section to be aligned, so take into account any padding
+    {
+        const size_t offset_pad = offset % ctx->alignment;
+
+        if (offset_pad != 0) {
+            offset += ctx->alignment - offset_pad;
+            fseek(file, offset, SEEK_SET);
+        }
+    }
+
+    // store the current file offset - this is where the data section starts
+    ctx->offset = offset;
+
+    // compute the total size of the data section, taking into account the alignment
+    {
+        ctx->size = 0;
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+            struct gguf_tensor_info * info = &ctx->infos[i];
+
+            const int64_t ne =
+                (int64_t) info->ne[0] *
+                (int64_t) info->ne[1] *
+                (int64_t) info->ne[2] *
+                (int64_t) info->ne[3];
+
+            if (ggml_blck_size(info->type) == 0) {
+                // support for this tensor type has been removed:
+                fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
+                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
+                gguf_free(ctx);
+                return NULL;
+            }
+
+            if (ne % ggml_blck_size(info->type) != 0) {
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
+                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
+                gguf_free(ctx);
+                return NULL;
+            }
+
+            const size_t size_cur = ggml_row_size(info->type, ne);
+
+            ctx->size += GGML_PAD(size_cur, ctx->alignment);
+        }
+    }
+
+    // load the tensor data only if requested
+    if (params.ctx != NULL) {
+        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
+        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
+        // the ggml_tensor structs to the appropriate locations in the binary blob
+
+        // compute the exact size needed for the new ggml_context
+        const size_t mem_size =
+            params.no_alloc ?
+            (ctx->header.n_tensors    )*ggml_tensor_overhead() :
+            (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+
+        struct ggml_init_params pdata = {
+            .mem_size   = mem_size,
+            .mem_buffer = NULL,
+            .no_alloc   = params.no_alloc,
+        };
+
+        *params.ctx = ggml_init(pdata);
+        if (*params.ctx == NULL) {
+            fprintf(stderr, "%s: failed to initialize context\n", __func__);
+            gguf_free(ctx);
+            return NULL;
+        }
+
+        struct ggml_context * ctx_data = *params.ctx;
+
+        struct ggml_tensor * data = NULL;
+
+        if (!params.no_alloc) {
+            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
+
+            ok = ok && data != NULL;
+
+            // read the binary blob with the tensor data
+            ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
+
+            if (!ok) {
+                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
+                ggml_free(ctx_data);
+                gguf_free(ctx);
+                return NULL;
+            }
+
+            ctx->data = data->data;
+        }
+
+        ggml_set_no_alloc(ctx_data, true);
+
+        // create the tensors
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+            const int64_t ne[GGML_MAX_DIMS] = {
+                ctx->infos[i].ne[0],
+                ctx->infos[i].ne[1],
+                ctx->infos[i].ne[2],
+                ctx->infos[i].ne[3],
+            };
+
+            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
+
+            ok = ok && cur != NULL;
+
+            if (!ok) {
+                break;
+            }
+
+            ggml_set_name(cur, ctx->infos[i].name.data);
+
+            // point the data member to the appropriate location in the binary blob using the tensor infos
+            if (!params.no_alloc) {
+              //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
+                cur->data = (char *) data->data + ctx->infos[i].offset;               // offset from data
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
+            ggml_free(ctx_data);
+            gguf_free(ctx);
+            return NULL;
+        }
+
+        ggml_set_no_alloc(ctx_data, params.no_alloc);
+    }
+
+    return ctx;
+}
+
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+    FILE * file = ggml_fopen(fname, "rb");
+    if (!file) {
+        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
+        return NULL;
+    }
+
+    struct gguf_context * result = gguf_init_from_file_impl(file, params);
+    fclose(file);
+    return result;
+}
+
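+// A minimal read-only usage sketch of the GGUF reader above (illustrative only; the file name
+// "model.gguf" is assumed):
+//
+//   {
+//       struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
+//
+//       struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
+//       if (gctx) {
+//           for (int i = 0; i < gguf_get_n_kv(gctx); ++i) {
+//               printf("%s: %s\n", gguf_get_key(gctx, i), gguf_type_name(gguf_get_kv_type(gctx, i)));
+//           }
+//           gguf_free(gctx);
+//       }
+//   }
+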
+void gguf_free(struct gguf_context * ctx) {
+    if (ctx == NULL) {
+        return;
+    }
+
+    if (ctx->kv) {
+        // free string memory - not great..
+        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
+            gguf_free_kv(&ctx->kv[i]);
+        }
+
+        GGML_FREE(ctx->kv);
+    }
+
+    if (ctx->infos) {
+        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
+            struct gguf_tensor_info * info = &ctx->infos[i];
+
+            if (info->name.data) {
+                GGML_FREE(info->name.data);
+            }
+        }
+
+        GGML_FREE(ctx->infos);
+    }
+
+    GGML_FREE(ctx);
+}
+
+const char * gguf_type_name(enum gguf_type type) {
+    return GGUF_TYPE_NAME[type];
+}
+
+int gguf_get_version(const struct gguf_context * ctx) {
+    return ctx->header.version;
+}
+
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
+    return ctx->alignment;
+}
+
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
+    return ctx->offset;
+}
+
+void * gguf_get_data(const struct gguf_context * ctx) {
+    return ctx->data;
+}
+
+int gguf_get_n_kv(const struct gguf_context * ctx) {
+    return ctx->header.n_kv;
+}
+
+int gguf_find_key(const struct gguf_context * ctx, const char * key) {
+    // return -1 if key not found
+    int keyfound = -1;
+
+    const int n_kv = gguf_get_n_kv(ctx);
+
+    for (int i = 0; i < n_kv; ++i) {
+        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
+            keyfound = i;
+            break;
+        }
+    }
+
+    return keyfound;
+}
+
+const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].key.data;
+}
+
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].type;
+}
+
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    return ctx->kv[key_id].value.arr.type;
+}
+
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    return ctx->kv[key_id].value.arr.data;
+}
+
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    struct gguf_kv * kv = &ctx->kv[key_id];
+    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
+    return str->data;
+}
+
+int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
+    return ctx->kv[key_id].value.arr.n;
+}
+
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
+    return ctx->kv[key_id].value.uint8;
+}
+
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
+    return ctx->kv[key_id].value.int8;
+}
+
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
+    return ctx->kv[key_id].value.uint16;
+}
+
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
+    return ctx->kv[key_id].value.int16;
+}
+
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
+    return ctx->kv[key_id].value.uint32;
+}
+
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
+    return ctx->kv[key_id].value.int32;
+}
+
+float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
+    return ctx->kv[key_id].value.float32;
+}
+
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
+    return ctx->kv[key_id].value.uint64;
+}
+
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
+    return ctx->kv[key_id].value.int64;
+}
+
+double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
+    return ctx->kv[key_id].value.float64;
+}
+
+bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
+    return ctx->kv[key_id].value.bool_;
+}
+
+const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
+    return ctx->kv[key_id].value.str.data;
+}
+
+const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
+    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
+    return &ctx->kv[key_id].value;
+}
+
+int gguf_get_n_tensors(const struct gguf_context * ctx) {
+    return ctx->header.n_tensors;
+}
+
+int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
+    // return -1 if tensor not found
+    int tensorfound = -1;
+
+    const int n_tensors = gguf_get_n_tensors(ctx);
+
+    for (int i = 0; i < n_tensors; ++i) {
+        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
+            tensorfound = i;
+            break;
+        }
+    }
+
+    return tensorfound;
+}
+
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
+    return ctx->infos[i].offset;
+}
+
+char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
+    return ctx->infos[i].name.data;
+}
+
+enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
+    return ctx->infos[i].type;
+}
+
+// returns the index
+static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
+    const int idx = gguf_find_key(ctx, key);
+    if (idx >= 0) {
+        return idx;
+    }
+
+    const int n_kv = gguf_get_n_kv(ctx);
+
+    ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
+    ctx->kv[n_kv].key.n    = strlen(key);
+    ctx->kv[n_kv].key.data = strdup(key);
+    ctx->header.n_kv++;
+
+    return n_kv;
+}
+
+void gguf_remove_key(struct gguf_context * ctx, const char * key) {
+    const int idx = gguf_find_key(ctx, key);
+    if (idx >= 0) {
+        const int n_kv = gguf_get_n_kv(ctx);
+        gguf_free_kv(&ctx->kv[idx]);
+        for (int i = idx; i < n_kv-1; ++i) {
+            ctx->kv[i] = ctx->kv[i+1];
+        }
+        ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
+        ctx->header.n_kv--;
+    }
+}
+
+void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type        = GGUF_TYPE_UINT8;
+    ctx->kv[idx].value.uint8 = val;
+}
+
+void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type       = GGUF_TYPE_INT8;
+    ctx->kv[idx].value.int8 = val;
+}
+
+void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type         = GGUF_TYPE_UINT16;
+    ctx->kv[idx].value.uint16 = val;
+}
+
+void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type        = GGUF_TYPE_INT16;
+    ctx->kv[idx].value.int16 = val;
+}
+
+void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type         = GGUF_TYPE_UINT32;
+    ctx->kv[idx].value.uint32 = val;
+}
+
+void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type        = GGUF_TYPE_INT32;
+    ctx->kv[idx].value.int32 = val;
+}
+
+void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type          = GGUF_TYPE_FLOAT32;
+    ctx->kv[idx].value.float32 = val;
+}
+
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type         = GGUF_TYPE_UINT64;
+    ctx->kv[idx].value.uint64 = val;
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type        = GGUF_TYPE_INT64;
+    ctx->kv[idx].value.int64 = val;
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type          = GGUF_TYPE_FLOAT64;
+    ctx->kv[idx].value.float64 = val;
+}
+
+void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type        = GGUF_TYPE_BOOL;
+    ctx->kv[idx].value.bool_ = val;
+}
+
+void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type           = GGUF_TYPE_STRING;
+    ctx->kv[idx].value.str.n    = strlen(val);
+    ctx->kv[idx].value.str.data = strdup(val);
+}
+
+void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
+    ctx->kv[idx].value.arr.type = type;
+    ctx->kv[idx].value.arr.n    = n;
+    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
+    memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
+}
+
+void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
+    const int idx = gguf_get_or_add_key(ctx, key);
+
+    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
+    ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
+    ctx->kv[idx].value.arr.n    = n;
+    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
+    for (int i = 0; i < n; i++) {
+        struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
+        str->n    = strlen(data[i]);
+        str->data = strdup(data[i]);
+    }
+}
+
+// set or add KV pairs from another context
+void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
+    for (uint32_t i = 0; i < src->header.n_kv; i++) {
+        switch (src->kv[i].type) {
+            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, src->kv[i].key.data, src->kv[i].value.uint8);    break;
+            case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, src->kv[i].key.data, src->kv[i].value.int8);     break;
+            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16);   break;
+            case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16);    break;
+            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32);   break;
+            case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32);    break;
+            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32);  break;
+            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64);   break;
+            case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64);    break;
+            case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64);  break;
+            case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_);    break;
+            case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
+            case GGUF_TYPE_ARRAY:
+                {
+                    if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
+                        const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
+                        for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
+                            data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
+                        }
+                        gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
+                        GGML_FREE((void *)data);
+                    } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
+                        GGML_ABORT("nested arrays not supported");
+                    } else {
+                        gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
+                    }
+                } break;
+            default: GGML_ABORT("invalid type");
+        }
+    }
+}
+
+void gguf_add_tensor(
+             struct gguf_context * ctx,
+        const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+    if (gguf_find_tensor(ctx, tensor->name) != -1) {
+        GGML_ABORT("duplicated tensor name");
+    }
+
+    const int idx = ctx->header.n_tensors;
+    ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
+
+    ctx->infos[idx].name.n    = strlen(tensor->name);
+    ctx->infos[idx].name.data = strdup(tensor->name);
+
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        ctx->infos[idx].ne[i] = 1;
+    }
+
+    ctx->infos[idx].n_dims = ggml_n_dims(tensor);
+    for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
+        ctx->infos[idx].ne[i] = tensor->ne[i];
+    }
+
+    ctx->infos[idx].type   = tensor->type;
+    ctx->infos[idx].offset = 0;
+    ctx->infos[idx].data   = tensor->data;
+    ctx->infos[idx].size   = ggml_nbytes(tensor);
+
+    if (ctx->header.n_tensors > 0) {
+        ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
+    }
+
+    ctx->header.n_tensors++;
+}
+
+void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
+    const int idx = gguf_find_tensor(ctx, name);
+    if (idx < 0) {
+        GGML_ABORT("tensor not found");
+    }
+
+    ctx->infos[idx].type = type;
+}
+
+void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
+    const int idx = gguf_find_tensor(ctx, name);
+    if (idx < 0) {
+        GGML_ABORT("tensor not found");
+    }
+
+    ctx->infos[idx].data = data;
+    ctx->infos[idx].size = size;
+
+    // update offsets
+    for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
+        ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
+    }
+}
+
+//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
+//    fwrite(&val->n,   sizeof(val->n),    1, file);
+//    fwrite(val->data, sizeof(char), val->n, file);
+//}
+//
+//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
+//    fwrite(val, sizeof(char), size, file);
+//}
+
+struct gguf_buf gguf_buf_init(size_t size) {
+    struct gguf_buf buf = {
+        /*buf.data   =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
+        /*buf.size   =*/ size,
+        /*buf.offset =*/ 0,
+    };
+
+    return buf;
+}
+
+void gguf_buf_free(struct gguf_buf buf) {
+    if (buf.data) {
+        GGML_FREE(buf.data);
+    }
+}
+
+static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
+    if (buf->offset + size > buf->size) {
+        buf->size = 1.5*(buf->offset + size);
+        if (buf->data) {
+            buf->data = realloc(buf->data, buf->size);
+        }
+    }
+}
+
+static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
+    gguf_buf_grow(buf, sizeof(val->n) + val->n);
+
+    if (buf->data) {
+        memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
+    }
+    buf->offset += sizeof(val->n);
+
+    if (buf->data) {
+        memcpy((char *) buf->data + buf->offset, val->data, val->n);
+    }
+    buf->offset += val->n;
+}
+
+static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
+    gguf_buf_grow(buf, el_size);
+
+    if (buf->data) {
+        memcpy((char *) buf->data + buf->offset, val, el_size);
+    }
+    buf->offset += el_size;
+}
+
+void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
+    // write header
+    gguf_bwrite_el(buf, &ctx->header.magic,     sizeof(ctx->header.magic));
+    gguf_bwrite_el(buf, &ctx->header.version,   sizeof(ctx->header.version));
+    gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
+    gguf_bwrite_el(buf, &ctx->header.n_kv,      sizeof(ctx->header.n_kv));
+
+    // write key-value pairs
+    for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
+        struct gguf_kv * kv = &ctx->kv[i];
+
+        gguf_bwrite_str(buf, &kv->key);
+        gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
+
+        switch (kv->type) {
+            case GGUF_TYPE_UINT8:   gguf_bwrite_el( buf, &kv->value.uint8,   sizeof(kv->value.uint8)  ); break;
+            case GGUF_TYPE_INT8:    gguf_bwrite_el (buf, &kv->value.int8,    sizeof(kv->value.int8)   ); break;
+            case GGUF_TYPE_UINT16:  gguf_bwrite_el (buf, &kv->value.uint16,  sizeof(kv->value.uint16) ); break;
+            case GGUF_TYPE_INT16:   gguf_bwrite_el (buf, &kv->value.int16,   sizeof(kv->value.int16)  ); break;
+            case GGUF_TYPE_UINT32:  gguf_bwrite_el (buf, &kv->value.uint32,  sizeof(kv->value.uint32) ); break;
+            case GGUF_TYPE_INT32:   gguf_bwrite_el (buf, &kv->value.int32,   sizeof(kv->value.int32)  ); break;
+            case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
+            case GGUF_TYPE_UINT64:  gguf_bwrite_el (buf, &kv->value.uint64,  sizeof(kv->value.uint64) ); break;
+            case GGUF_TYPE_INT64:   gguf_bwrite_el (buf, &kv->value.int64,   sizeof(kv->value.int64)  ); break;
+            case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
+            case GGUF_TYPE_BOOL:    gguf_bwrite_el (buf, &kv->value.bool_,   sizeof(kv->value.bool_)  ); break;
+            case GGUF_TYPE_STRING:  gguf_bwrite_str(buf, &kv->value.str                               ); break;
+            case GGUF_TYPE_ARRAY:
+                {
+                    gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
+                    gguf_bwrite_el(buf, &kv->value.arr.n,    sizeof(kv->value.arr.n)   );
+
+                    switch (kv->value.arr.type) {
+                        case GGUF_TYPE_UINT8:
+                        case GGUF_TYPE_INT8:
+                        case GGUF_TYPE_UINT16:
+                        case GGUF_TYPE_INT16:
+                        case GGUF_TYPE_UINT32:
+                        case GGUF_TYPE_INT32:
+                        case GGUF_TYPE_FLOAT32:
+                        case GGUF_TYPE_UINT64:
+                        case GGUF_TYPE_INT64:
+                        case GGUF_TYPE_FLOAT64:
+                        case GGUF_TYPE_BOOL:
+                            {
+                                gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
+                            } break;
+                        case GGUF_TYPE_STRING:
+                            {
+                                for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
+                                    gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
+                                }
+                            } break;
+                        case GGUF_TYPE_ARRAY:
+                        default: GGML_ABORT("invalid type");
+                    }
+                } break;
+            default: GGML_ABORT("invalid type");
+        }
+    }
+
+    // write tensor infos
+    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+        struct gguf_tensor_info * info = &ctx->infos[i];
+
+        gguf_bwrite_str(buf, &info->name);
+        gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
+        for (uint32_t j = 0; j < info->n_dims; ++j) {
+            gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
+        }
+        gguf_bwrite_el(buf, &info->type,   sizeof(info->type));
+        gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
+    }
+
+    // we require the data section to be aligned, so take into account any padding
+    {
+        const size_t offset     = buf->offset;
+        const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
+
+        if (offset_pad != offset) {
+            uint8_t pad = 0;
+            for (size_t i = 0; i < offset_pad - offset; ++i) {
+                gguf_bwrite_el(buf, &pad, sizeof(pad));
+            }
+        }
+    }
+
+    if (only_meta) {
+        return;
+    }
+
+    size_t offset = 0;
+
+    // write tensor data
+    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+        struct gguf_tensor_info * info = &ctx->infos[i];
+
+        const size_t size     = info->size;
+        const size_t size_pad = GGML_PAD(size, ctx->alignment);
+
+        gguf_bwrite_el(buf, info->data, size);
+
+        if (size_pad != size) {
+            uint8_t pad = 0;
+            for (size_t j = 0; j < size_pad - size; ++j) {
+                gguf_bwrite_el(buf, &pad, sizeof(pad));
+            }
+        }
+
+        GGML_ASSERT(offset == info->offset);
+
+        offset += size_pad;
+    }
+}
+
+void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
+    FILE * file = ggml_fopen(fname, "wb");
+    if (!file) {
+        GGML_ABORT("failed to open file for writing");
+    }
+
+    struct gguf_buf buf = gguf_buf_init(16*1024);
+
+    gguf_write_to_buf(ctx, &buf, only_meta);
+
+    fwrite(buf.data, 1, buf.offset, file);
+
+    gguf_buf_free(buf);
+
+    fclose(file);
+}
+
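+// A small sketch of patching metadata in an existing file and writing only the header, KV pairs and
+// tensor infos back out (illustrative only; the file names and the new "general.name" value are assumed):
+//
+//   {
+//       struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
+//
+//       struct gguf_context * gctx = gguf_init_from_file("in.gguf", params);
+//       if (gctx) {
+//           gguf_set_val_str(gctx, "general.name", "renamed-model");
+//           gguf_write_to_file(gctx, "out-meta.gguf", /*only_meta =*/ true);
+//           gguf_free(gctx);
+//       }
+//   }
+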
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
+    // no allocs - only compute size
+    struct gguf_buf buf = gguf_buf_init(0);
+
+    gguf_write_to_buf(ctx, &buf, true);
+
+    return buf.offset;
+}
+
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
+    struct gguf_buf buf = gguf_buf_init(16*1024);
+
+    gguf_write_to_buf(ctx, &buf, true);
+
+    memcpy(data, buf.data, buf.offset);
+
+    gguf_buf_free(buf);
+}
+
+void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
+    g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
+    g_logger_state.log_callback_user_data = user_data;
+}
+
+void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) {
+    p->n_threads  = n_threads;
+    p->prio       = 0;     // default priority (usually means normal or inherited)
+    p->poll       = 50;    // hybrid-polling enabled
+    p->strict_cpu = false; // no strict placement (all threads share same cpumask)
+    p->paused     = false; // threads are ready to go
+    memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited)
+}
+
+struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) {
+    struct ggml_threadpool_params p;
+    ggml_threadpool_params_init(&p, n_threads);
+    return p;
+}
+
+bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) {
+    if (p0->n_threads      != p1->n_threads  )    return false;
+    if (p0->prio           != p1->prio       )    return false;
+    if (p0->poll           != p1->poll       )    return false;
+    if (p0->strict_cpu     != p1->strict_cpu )    return false;
+    return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0;
+}
diff --git a/llama/ggml.h b/llama/ggml.h
new file mode 100644
index 000000000..621362c83
--- /dev/null
+++ b/llama/ggml.h
@@ -0,0 +1,2338 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+//
+// GGML Tensor Library
+//
+// This documentation is still a work in progress.
+// If you would like specific topics to be covered, feel free to drop a comment:
+//
+//   https://github.com/ggerganov/whisper.cpp/issues/40
+//
+// ## Overview
+//
+// This library implements:
+//
+//  - a set of tensor operations
+//  - automatic differentiation
+//  - basic optimization algorithms
+//
+// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
+// but is not limited to, the following:
+//
+//  - linear regression
+//  - support vector machines
+//  - neural networks
+//
+// The library allows the user to define a certain function using the available tensor operations. This function
+// definition is represented internally via a computation graph. Each tensor operation in the function definition
+// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
+// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
+// using one of the available optimization algorithms.
+//
+// For example, here we define the function: f(x) = a*x^2 + b
+//
+//   {
+//       struct ggml_init_params params = {
+//           .mem_size   = 16*1024*1024,
+//           .mem_buffer = NULL,
+//       };
+//
+//       // memory allocation happens here
+//       struct ggml_context * ctx = ggml_init(params);
+//
+//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//
+//       ggml_set_param(ctx, x); // x is an input variable
+//
+//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
+//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
+//
+//       ...
+//   }
+//
+// Notice that the function definition above does not involve any actual computation. The computation is performed only
+// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
+//
+//   {
+//       ...
+//
+//       struct ggml_cgraph * gf = ggml_new_graph(ctx);
+//       ggml_build_forward_expand(gf, f);
+//
+//       // set the input variable and parameter values
+//       ggml_set_f32(x, 2.0f);
+//       ggml_set_f32(a, 3.0f);
+//       ggml_set_f32(b, 4.0f);
+//
+//       ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
+//
+//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
+//
+//       ...
+//   }
+//
+// The actual computation is performed in the ggml_graph_compute() function.
+//
+// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
+// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
+// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
+// buffer and, after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
+// actually needed.
+//
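+// For instance, a possible sizing sketch: start with a generous buffer, define the graph, then query
+// ggml_used_mem() to learn how much was actually required:
+//
+//   {
+//       struct ggml_init_params params = {
+//           .mem_size   = 128*1024*1024,   // assumed to be "large enough" for this sketch
+//           .mem_buffer = NULL,
+//       };
+//
+//       struct ggml_context * ctx = ggml_init(params);
+//
+//       // ... create tensors and build the computation graph here ...
+//
+//       const size_t needed = ggml_used_mem(ctx);   // bytes actually consumed so far
+//       ggml_free(ctx);
+//
+//       // a later ggml_init() with mem_size >= needed is sufficient for the same graph
+//   }
+//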
+// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
+// differentiation and optimization algorithms.
+//
+// The described approach makes it possible to define the function graph once and then compute its forward or backward graphs
+// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
+// the user can avoid the memory allocation overhead at runtime.
+//
+// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
+// citizens, but in theory the library can be extended to support FP8 and integer data types.
+//
+// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
+// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
+// clear that the library needs to support more complex operations. The way to support these operations is not clear
+// yet, but a few examples are demonstrated in the following operations:
+//
+//   - ggml_permute()
+//   - ggml_conv_1d_1s()
+//   - ggml_conv_1d_2s()
+//
+// For each tensor operator, the library implements a forward and backward computation function. The forward function
+// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
+// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
+// calculus class, or watch the following video:
+//
+//   What is Automatic Differentiation?
+//   https://www.youtube.com/watch?v=wG_nF1awSSY
+//
+//
+// ## Tensor data (struct ggml_tensor)
+//
+// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
+// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
+// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
+//
+//   {
+//       struct ggml_tensor * c = ggml_add(ctx, a, b);
+//
+//       assert(c->src[0] == a);
+//       assert(c->src[1] == b);
+//   }
+//
+// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
+// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This makes it
+// possible to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
+// permutation. All tensor operations have to take the stride into account and not assume that the tensor is
+// contiguous in memory.
+//
+// The data of the tensor is accessed via the "data" pointer. For example:
+//
+//   {
+//       const int nx = 2;
+//       const int ny = 3;
+//
+//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
+//
+//       for (int y = 0; y < ny; y++) {
+//           for (int x = 0; x < nx; x++) {
+//               *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
+//           }
+//       }
+//
+//       ...
+//   }
+//
+// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
+//
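+// For example, a tiny sketch (assuming "a" is the 2D F32 tensor created above):
+//
+//   {
+//       ggml_set_f32_1d(a, 0, 1.0f);
+//
+//       const float v = ggml_get_f32_1d(a, 0); // v == 1.0f
+//   }
+//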
+// ## The matrix multiplication operator (ggml_mul_mat)
+//
+// TODO
+//
+//
+// ## Multi-threading
+//
+// TODO
+//
+//
+// ## Overview of ggml.c
+//
+// TODO
+//
+//
+// ## SIMD optimizations
+//
+// TODO
+//
+//
+// ## Debugging ggml
+//
+// TODO
+//
+//
+
+#ifdef GGML_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BUILD
+#            define GGML_API __declspec(dllexport) extern
+#        else
+#            define GGML_API __declspec(dllimport) extern
+#        endif
+#    else
+#        define GGML_API __attribute__ ((visibility ("default"))) extern
+#    endif
+#else
+#    define GGML_API extern
+#endif
+
+// TODO: support for clang
+#ifdef __GNUC__
+#    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define GGML_DEPRECATED(func, hint) func
+#endif
+
+#ifndef __GNUC__
+#    define GGML_ATTRIBUTE_FORMAT(...)
+#elif defined(__MINGW32__)
+#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
+#else
+#    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
+#endif
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
+#define GGML_FILE_VERSION 2
+
+#define GGML_QNT_VERSION        2    // bump this on quantization format changes
+#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
+
+#define GGML_MAX_DIMS           4
+#define GGML_MAX_PARAMS         2048
+#define GGML_MAX_SRC            10
+#define GGML_MAX_N_THREADS      512
+#define GGML_MAX_OP_PARAMS      64
+
+#ifndef GGML_MAX_NAME
+#   define GGML_MAX_NAME        64
+#endif
+
+#define GGML_DEFAULT_N_THREADS  4
+#define GGML_DEFAULT_GRAPH_SIZE 2048
+
+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
+#define GGML_ROPE_TYPE_NEOX   2
+#define GGML_ROPE_TYPE_MROPE  8
+#define GGML_ROPE_TYPE_VISION 24
+
+#define GGUF_MAGIC "GGUF"
+
+#define GGUF_VERSION 3
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#define GGML_UNUSED(x) (void)(x)
+
+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
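+// worked example (assumes n is a power of two, e.g. GGUF_DEFAULT_ALIGNMENT): GGML_PAD(13, 8) == 16, GGML_PAD(16, 8) == 16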
+
+#ifndef NDEBUG
+#   define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
+#elif defined(__GNUC__)
+#   define GGML_UNREACHABLE() __builtin_unreachable()
+#elif defined(_MSC_VER)
+#   define GGML_UNREACHABLE() __assume(0)
+#else
+#   define GGML_UNREACHABLE() ((void) 0)
+#endif
+
+#ifdef __cplusplus
+#   define GGML_NORETURN [[noreturn]]
+#elif defined(_MSC_VER)
+#   define GGML_NORETURN __declspec(noreturn)
+#else
+#   define GGML_NORETURN _Noreturn
+#endif
+
+#define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
+#define GGML_ASSERT(x) if (!(x)) GGML_ABORT("GGML_ASSERT(%s) failed", #x)
+
+// used to copy the number of elements and stride in bytes of tensors into local variables.
+// main purpose is to reduce code duplication and improve readability.
+//
+// example:
+//
+//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+//
+#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
+    const type prefix##0 = (pointer)->array[0]; \
+    GGML_UNUSED(prefix##0);
+#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \
+    const type prefix##1 = (pointer)->array[1]; \
+    GGML_UNUSED(prefix##1);
+#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \
+    const type prefix##2 = (pointer)->array[2]; \
+    GGML_UNUSED(prefix##2);
+#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \
+    const type prefix##3 = (pointer)->array[3]; \
+    GGML_UNUSED(prefix##3);
+
+#define GGML_TENSOR_UNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
+#define GGML_TENSOR_BINARY_OP_LOCALS01 \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4)
+    GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...);
+
+    enum ggml_status {
+        GGML_STATUS_ALLOC_FAILED = -2,
+        GGML_STATUS_FAILED = -1,
+        GGML_STATUS_SUCCESS = 0,
+        GGML_STATUS_ABORTED = 1,
+    };
+
+    // get ggml_status name string
+    GGML_API const char * ggml_status_to_string(enum ggml_status status);
+
+    // ieee 754-2008 half-precision float16
+    // todo: make this not an integral type
+    typedef uint16_t ggml_fp16_t;
+    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t);
+    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
+    GGML_API void        ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
+    GGML_API void        ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
+
+    // google brain half-precision bfloat16
+    typedef struct { uint16_t bits; } ggml_bf16_t;
+    GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
+    GGML_API float       ggml_bf16_to_fp32(ggml_bf16_t);  // consider just doing << 16
+    GGML_API void        ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
+    GGML_API void        ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t);
+    GGML_API void        ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
+
+    struct ggml_object;
+    struct ggml_context;
+    struct ggml_cgraph;
+
+    // NOTE: always add types at the end of the enum to keep backward compatibility
+    enum ggml_type {
+        GGML_TYPE_F32     = 0,
+        GGML_TYPE_F16     = 1,
+        GGML_TYPE_Q4_0    = 2,
+        GGML_TYPE_Q4_1    = 3,
+        // GGML_TYPE_Q4_2 = 4, support has been removed
+        // GGML_TYPE_Q4_3 = 5, support has been removed
+        GGML_TYPE_Q5_0    = 6,
+        GGML_TYPE_Q5_1    = 7,
+        GGML_TYPE_Q8_0    = 8,
+        GGML_TYPE_Q8_1    = 9,
+        GGML_TYPE_Q2_K    = 10,
+        GGML_TYPE_Q3_K    = 11,
+        GGML_TYPE_Q4_K    = 12,
+        GGML_TYPE_Q5_K    = 13,
+        GGML_TYPE_Q6_K    = 14,
+        GGML_TYPE_Q8_K    = 15,
+        GGML_TYPE_IQ2_XXS = 16,
+        GGML_TYPE_IQ2_XS  = 17,
+        GGML_TYPE_IQ3_XXS = 18,
+        GGML_TYPE_IQ1_S   = 19,
+        GGML_TYPE_IQ4_NL  = 20,
+        GGML_TYPE_IQ3_S   = 21,
+        GGML_TYPE_IQ2_S   = 22,
+        GGML_TYPE_IQ4_XS  = 23,
+        GGML_TYPE_I8      = 24,
+        GGML_TYPE_I16     = 25,
+        GGML_TYPE_I32     = 26,
+        GGML_TYPE_I64     = 27,
+        GGML_TYPE_F64     = 28,
+        GGML_TYPE_IQ1_M   = 29,
+        GGML_TYPE_BF16    = 30,
+        // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
+        // GGML_TYPE_Q4_0_4_8 = 32,
+        // GGML_TYPE_Q4_0_8_8 = 33,
+        GGML_TYPE_TQ1_0   = 34,
+        GGML_TYPE_TQ2_0   = 35,
+        // GGML_TYPE_IQ4_NL_4_4 = 36,
+        // GGML_TYPE_IQ4_NL_4_8 = 37,
+        // GGML_TYPE_IQ4_NL_8_8 = 38,
+        GGML_TYPE_COUNT   = 39,
+    };
+
+    // precision
+    enum ggml_prec {
+        GGML_PREC_DEFAULT,
+        GGML_PREC_F32,
+    };
+
+    enum ggml_backend_type {
+        GGML_BACKEND_TYPE_CPU = 0,
+        GGML_BACKEND_TYPE_GPU = 10,
+        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
+    };
+
+    // model file types
+    enum ggml_ftype {
+        GGML_FTYPE_UNKNOWN        = -1,
+        GGML_FTYPE_ALL_F32        = 0,
+        GGML_FTYPE_MOSTLY_F16     = 1,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0    = 2,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_1    = 3,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
+        GGML_FTYPE_MOSTLY_Q8_0    = 7,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_0    = 8,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_1    = 9,  // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q2_K    = 10, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q3_K    = 11, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_K    = 12, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q5_K    = 13, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q6_K    = 14, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_XS  = 16, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_S   = 18, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_NL  = 19, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ3_S   = 20, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ2_S   = 21, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
+        GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
+        GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
+    };
+
+    // available tensor operations:
+    enum ggml_op {
+        GGML_OP_NONE = 0,
+
+        GGML_OP_DUP,
+        GGML_OP_ADD,
+        GGML_OP_ADD1,
+        GGML_OP_ACC,
+        GGML_OP_SUB,
+        GGML_OP_MUL,
+        GGML_OP_DIV,
+        GGML_OP_SQR,
+        GGML_OP_SQRT,
+        GGML_OP_LOG,
+        GGML_OP_SIN,
+        GGML_OP_COS,
+        GGML_OP_SUM,
+        GGML_OP_SUM_ROWS,
+        GGML_OP_MEAN,
+        GGML_OP_ARGMAX,
+        GGML_OP_COUNT_EQUAL,
+        GGML_OP_REPEAT,
+        GGML_OP_REPEAT_BACK,
+        GGML_OP_CONCAT,
+        GGML_OP_SILU_BACK,
+        GGML_OP_NORM, // normalize
+        GGML_OP_RMS_NORM,
+        GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,
+
+        GGML_OP_MUL_MAT,
+        GGML_OP_MUL_MAT_ID,
+        GGML_OP_OUT_PROD,
+
+        GGML_OP_SCALE,
+        GGML_OP_SET,
+        GGML_OP_CPY,
+        GGML_OP_CONT,
+        GGML_OP_RESHAPE,
+        GGML_OP_VIEW,
+        GGML_OP_PERMUTE,
+        GGML_OP_TRANSPOSE,
+        GGML_OP_GET_ROWS,
+        GGML_OP_GET_ROWS_BACK,
+        GGML_OP_DIAG,
+        GGML_OP_DIAG_MASK_INF,
+        GGML_OP_DIAG_MASK_ZERO,
+        GGML_OP_SOFT_MAX,
+        GGML_OP_SOFT_MAX_BACK,
+        GGML_OP_ROPE,
+        GGML_OP_ROPE_BACK,
+        GGML_OP_CLAMP,
+        GGML_OP_CONV_TRANSPOSE_1D,
+        GGML_OP_IM2COL,
+        GGML_OP_IM2COL_BACK,
+        GGML_OP_CONV_TRANSPOSE_2D,
+        GGML_OP_POOL_1D,
+        GGML_OP_POOL_2D,
+        GGML_OP_POOL_2D_BACK,
+        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_PAD,
+        GGML_OP_PAD_REFLECT_1D,
+        GGML_OP_UNPAD,
+        GGML_OP_ARANGE,
+        GGML_OP_TIMESTEP_EMBEDDING,
+        GGML_OP_ARGSORT,
+        GGML_OP_LEAKY_RELU,
+
+        GGML_OP_FLASH_ATTN_EXT,
+        GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_SSM_CONV,
+        GGML_OP_SSM_SCAN,
+        GGML_OP_WIN_PART,
+        GGML_OP_WIN_UNPART,
+        GGML_OP_GET_REL_POS,
+        GGML_OP_ADD_REL_POS,
+        GGML_OP_RWKV_WKV6,
+
+        GGML_OP_UNARY,
+
+        GGML_OP_MAP_UNARY,
+        GGML_OP_MAP_BINARY,
+
+        GGML_OP_MAP_CUSTOM1_F32,
+        GGML_OP_MAP_CUSTOM2_F32,
+        GGML_OP_MAP_CUSTOM3_F32,
+
+        GGML_OP_MAP_CUSTOM1,
+        GGML_OP_MAP_CUSTOM2,
+        GGML_OP_MAP_CUSTOM3,
+
+        GGML_OP_CROSS_ENTROPY_LOSS,
+        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+        GGML_OP_OPT_STEP_ADAMW,
+
+        GGML_OP_COUNT,
+    };
+
+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_SIGMOID,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+        GGML_UNARY_OP_HARDSWISH,
+        GGML_UNARY_OP_HARDSIGMOID,
+        GGML_UNARY_OP_EXP,
+
+        GGML_UNARY_OP_COUNT,
+    };
+
+    enum ggml_object_type {
+        GGML_OBJECT_TYPE_TENSOR,
+        GGML_OBJECT_TYPE_GRAPH,
+        GGML_OBJECT_TYPE_WORK_BUFFER
+    };
+
+    enum ggml_log_level {
+        GGML_LOG_LEVEL_NONE  = 0,
+        GGML_LOG_LEVEL_DEBUG = 1,
+        GGML_LOG_LEVEL_INFO  = 2,
+        GGML_LOG_LEVEL_WARN  = 3,
+        GGML_LOG_LEVEL_ERROR = 4,
+        GGML_LOG_LEVEL_CONT  = 5, // continue previous log
+    };
+
+    // this tensor...
+    enum ggml_tensor_flag {
+        GGML_TENSOR_FLAG_INPUT  =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM  =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+    };
+
+    struct ggml_init_params {
+        // memory pool
+        size_t mem_size;   // bytes
+        void * mem_buffer; // if NULL, memory will be allocated internally
+        bool   no_alloc;   // don't allocate memory for the tensor data
+    };
+
+    // n-dimensional tensor
+    struct ggml_tensor {
+        enum ggml_type type;
+
+        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
+
+        struct ggml_backend_buffer * buffer;
+
+        int64_t ne[GGML_MAX_DIMS]; // number of elements
+        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
+                                   // nb[0] = ggml_type_size(type)
+                                   // nb[1] = nb[0]   * (ne[0] / ggml_blck_size(type)) + padding
+                                   // nb[i] = nb[i-1] * ne[i-1]
+
+        // compute data
+        enum ggml_op op;
+
+        // op params - allocated as int32_t for alignment
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+
+        int32_t flags;
+
+        struct ggml_tensor * src[GGML_MAX_SRC];
+
+        // source tensor and offset for views
+        struct ggml_tensor * view_src;
+        size_t               view_offs;
+
+        void * data;
+
+        char name[GGML_MAX_NAME];
+
+        void * extra; // extra things e.g. for ggml-cuda.cu
+
+        char padding[8];
+    };
+
+    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
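+
+    // Worked example of the stride formulas above (illustrative; assumes a
+    // contiguous GGML_TYPE_F32 tensor with ne = {4, 3, 1, 1} and no padding):
+    //   nb[0] = ggml_type_size(GGML_TYPE_F32)               =  4 bytes
+    //   nb[1] = nb[0] * (ne[0] / ggml_blck_size(F32)) + 0    = 16 bytes
+    //   nb[2] = nb[1] * ne[1]                                = 48 bytes
+    //   nb[3] = nb[2] * ne[2]                                = 48 bytes
+    //   ggml_nbytes() therefore reports 48 bytes for the whole tensor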
+
+    // Abort callback
+    // If not NULL, called before ggml computation
+    // If it returns true, the computation is aborted
+    typedef bool (*ggml_abort_callback)(void * data);
+
+
+    //
+    // GUID
+    //
+
+    // GUID types
+    typedef uint8_t ggml_guid[16];
+    typedef ggml_guid * ggml_guid_t;
+
+    GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b);
+
+    // misc
+
+    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
+    GGML_API int64_t ggml_time_ms(void);
+    GGML_API int64_t ggml_time_us(void);
+    GGML_API int64_t ggml_cycles(void);
+    GGML_API int64_t ggml_cycles_per_ms(void);
+
+    // accepts a UTF-8 path, even on Windows
+    GGML_API FILE *  ggml_fopen(const char * fname, const char * mode);
+
+    GGML_API void    ggml_print_object (const struct ggml_object * obj);
+    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
+
+    GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows     (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes    (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
+
+    GGML_API int64_t ggml_blck_size(enum ggml_type type);
+    GGML_API size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
+    GGML_API size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
+
+    GGML_DEPRECATED(
+    GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
+    "use ggml_row_size() instead");
+
+    GGML_API const char * ggml_type_name(enum ggml_type type);
+    GGML_API const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API const char * ggml_op_symbol(enum ggml_op   op);
+
+    GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
+
+    GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
+
+    GGML_API bool    ggml_is_quantized(enum ggml_type type);
+
+    // TODO: temporary until model loading of ggml examples is refactored
+    GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
+
+    GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_empty     (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_scalar    (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_vector    (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_matrix    (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_3d        (const struct ggml_tensor * tensor);
+    GGML_API int  ggml_n_dims       (const struct ggml_tensor * tensor); // returns 1 for scalars
+
+    GGML_API bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
+    GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+    GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+
+    GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+
+    // use this to compute the memory overhead of a tensor
+    GGML_API size_t ggml_tensor_overhead(void);
+
+    GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes);
+
+    // main
+
+    GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
+    GGML_API void                  ggml_reset(struct ggml_context * ctx);
+    GGML_API void                  ggml_free (struct ggml_context * ctx);
+
+    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
+
+    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
+    GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
+
+    GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_mem_size       (const struct ggml_context * ctx);
+    GGML_API size_t  ggml_get_max_tensor_size(const struct ggml_context * ctx);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int    n_dims,
+            const int64_t *ne);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_1d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_2d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_3d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2);
+
+    GGML_API struct ggml_tensor * ggml_new_tensor_4d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2,
+            int64_t ne3);
+
+    GGML_API void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes);
+
+    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
+
+    // Context tensor enumeration and lookup
+    GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
+    GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
+
+    // Converts a flat index into coordinates
+    GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
+
+    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
+    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
+
+    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_ATTRIBUTE_FORMAT(2, 3)
+    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
+
+    // Tensor flags
+    GGML_API void ggml_set_input(struct ggml_tensor * tensor);
+    GGML_API void ggml_set_output(struct ggml_tensor * tensor);
+    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
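+
+    // Illustrative usage sketch (names such as x and w are placeholders, not part
+    // of this API): allocate a context, create tensors, and tag the graph input.
+    //
+    //     struct ggml_init_params params = {
+    //         /*.mem_size   =*/ 16*1024*1024, // 16 MiB pool
+    //         /*.mem_buffer =*/ NULL,         // let ggml allocate the pool itself
+    //         /*.no_alloc   =*/ false,
+    //     };
+    //     struct ggml_context * ctx = ggml_init(params);
+    //
+    //     struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
+    //     struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);
+    //     ggml_set_name(x, "x");
+    //     ggml_set_input(x);
+    //
+    //     // ... build operations (see below), compute, then release the pool
+    //     ggml_free(ctx);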
+
+    //
+    // operations on tensors with backpropagation
+    //
+
+    GGML_API struct ggml_tensor * ggml_dup(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_dup_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_add(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            enum   ggml_type      type);
+
+    GGML_API struct ggml_tensor * ggml_add1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // dst = a
+    // view(dst, nb1, nb2, nb3, offset) += b
+    // return dst
+    GGML_API struct ggml_tensor * ggml_acc(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_acc_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_sub(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sub_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_mul(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_mul_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_div(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_div_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sqr(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqr_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqrt(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqrt_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_log_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sin(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sin_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // return scalar
+    GGML_API struct ggml_tensor * ggml_sum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // sums along rows; for input shape [a,b,c,d] the result has shape [1,b,c,d]
+    GGML_API struct ggml_tensor * ggml_sum_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // mean along rows
+    GGML_API struct ggml_tensor * ggml_mean(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // argmax along rows
+    GGML_API struct ggml_tensor * ggml_argmax(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // count number of equal elements in a and b
+    GGML_API struct ggml_tensor * ggml_count_equal(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // if a is the same shape as b, and a is not a parameter, return a
+    // otherwise, return a new tensor: repeat(a) to fit in b
+    GGML_API struct ggml_tensor * ggml_repeat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
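+
+    // Illustrative example: with a->ne = {3, 1, 1, 1} and b->ne = {3, 4, 1, 1},
+    // ggml_repeat(ctx, a, b) returns a {3, 4, 1, 1} tensor in which the single
+    // row of a is repeated 4 times.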
+
+    // sums repetitions in a into shape of b
+    GGML_API struct ggml_tensor * ggml_repeat_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // concat a and b along dim
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   dim);
+
+    GGML_API struct ggml_tensor * ggml_abs(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_abs_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sgn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sgn_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_neg(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_neg_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_step(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_step_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_tanh(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_tanh_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_elu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_elu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_leaky_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a, float negative_slope, bool inplace);
+
+    GGML_API struct ggml_tensor * ggml_relu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_silu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_silu_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_silu_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // hardswish(x) = x * relu6(x + 3) / 6
+    GGML_API struct ggml_tensor * ggml_hardswish(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // hardsigmoid(x) = relu6(x + 3) / 6
+    GGML_API struct ggml_tensor * ggml_hardsigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_exp_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // normalize along rows
+    GGML_API struct ggml_tensor * ggml_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_rms_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups,
+            float                 eps);
+
+    // a - x
+    // b - dy
+    GGML_API struct ggml_tensor * ggml_rms_norm_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 eps);
+
+    // A: k columns, n rows => [ne03, ne02, n, k]
+    // B: k columns, m rows  (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
+    // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
+    GGML_API struct ggml_tensor * ggml_mul_mat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
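+
+    // Illustrative shapes: with a->ne = {k, n, 1, 1} (e.g. the weights) and
+    // b->ne = {k, m, 1, 1} (e.g. the activations), ggml_mul_mat(ctx, a, b)
+    // returns a tensor of shape {n, m, 1, 1}; row j of the result holds the
+    // dot products of row j of b with each of the n rows of a.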
+
+    // change the precision of a matrix multiplication
+    // set to GGML_PREC_F32 for higher precision (useful for phi-2)
+    GGML_API void ggml_mul_mat_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
+    // indirect matrix multiplication
+    GGML_API struct ggml_tensor * ggml_mul_mat_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * as,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
+    // A: m columns, n rows,
+    // B: p columns, n rows,
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_out_prod(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    //
+    // operations on tensors without backpropagation
+    //
+
+    GGML_API struct ggml_tensor * ggml_scale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_scale_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s);
+
+    // b -> view(a, offset, nb1, nb2, nb3), return modified a
+    GGML_API struct ggml_tensor * ggml_set(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset); // in bytes
+
+    // b -> view(a, offset, nb1, nb2, nb3), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                nb2,
+            size_t                nb3,
+            size_t                offset); // in bytes
+
+    GGML_API struct ggml_tensor * ggml_set_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset); // in bytes
+
+    GGML_API struct ggml_tensor * ggml_set_1d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                offset); // in bytes
+
+    // b -> view(a, offset, nb1), return modified a
+    GGML_API struct ggml_tensor * ggml_set_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset); // in bytes
+
+    // b -> view(a, offset, nb1), return view(a)
+    GGML_API struct ggml_tensor * ggml_set_2d_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            size_t                nb1,
+            size_t                offset); // in bytes
+
+    // a -> b, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_cast(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum   ggml_type      type);
+
+    // make contiguous
+    GGML_API struct ggml_tensor * ggml_cont(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // make contiguous, with new shape
+    GGML_API struct ggml_tensor * ggml_cont_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
+    GGML_API struct ggml_tensor * ggml_cont_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    GGML_API struct ggml_tensor * ggml_cont_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    GGML_API struct ggml_tensor * ggml_cont_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
+    // return view(a), b specifies the new shape
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0);
+
+    GGML_API struct ggml_tensor * ggml_reshape_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    GGML_API struct ggml_tensor * ggml_reshape_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3);
+
+    // offset in bytes
+    GGML_API struct ggml_tensor * ggml_view_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            size_t                nb1, // row stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_4d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                nb3,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_permute(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   axis0,
+            int                   axis1,
+            int                   axis2,
+            int                   axis3);
+
+    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+    GGML_API struct ggml_tensor * ggml_transpose(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // supports 3D: a->ne[2] == b->ne[1]
+    GGML_API struct ggml_tensor * ggml_get_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // data
+            struct ggml_tensor  * b); // row indices
+
+    GGML_API struct ggml_tensor * ggml_get_rows_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // gradients of ggml_get_rows result
+            struct ggml_tensor  * b,  // row indices
+            struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
+
+    GGML_API struct ggml_tensor * ggml_diag(
+        struct ggml_context     * ctx,
+        struct ggml_tensor      * a);
+
+    // set elements above the diagonal to -INF
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // set elements above the diagonal to 0
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    GGML_API struct ggml_tensor * ggml_soft_max(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // fused soft_max(a*scale + mask*(ALiBi slope))
+    // mask is optional
+    // max_bias = 0.0f for no ALiBi
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale,
+            float                 max_bias);
+
+    GGML_API struct ggml_tensor * ggml_soft_max_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // rotary position embedding
+    // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
+    // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
+    //
+    // b is an int32 vector with size a->ne[2], it contains the positions
+    GGML_API struct ggml_tensor * ggml_rope(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode);
+
+    // custom RoPE
+    // c is freq factors (optional, e.g. phi3-128k)
+    GGML_API struct ggml_tensor * ggml_rope_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
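+
+    // Illustrative call (q, pos and head_dim are placeholder names; the numeric
+    // values are common LLaMA-style settings, not defaults mandated by this header):
+    //
+    //     struct ggml_tensor * q_rot = ggml_rope_ext(ctx, q, pos, NULL,
+    //             /*n_dims     =*/ head_dim,
+    //             /*mode       =*/ GGML_ROPE_TYPE_NEOX,
+    //             /*n_ctx_orig =*/ 4096,
+    //             /*freq_base  =*/ 10000.0f,
+    //             /*freq_scale =*/ 1.0f,
+    //             /*ext_factor =*/ 0.0f,
+    //             /*attn_factor=*/ 1.0f,
+    //             /*beta_fast  =*/ 32.0f,
+    //             /*beta_slow  =*/ 1.0f);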
+
+    GGML_API struct ggml_tensor * ggml_rope_multi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[4],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow),
+        "use ggml_rope_ext_inplace instead");
+
+    // compute correction dims for YaRN RoPE scaling
+    GGML_API void ggml_rope_yarn_corr_dims(
+        int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+    // rotary position embedding backward, i.e. compute dx from dy
+    // a - dy
+    GGML_API struct ggml_tensor * ggml_rope_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a, // gradients of ggml_rope result
+            struct ggml_tensor  * b, // positions
+            struct ggml_tensor  * c, // freq factors
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // clamp
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_clamp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 min,
+            float                 max);
+
+    // im2col
+    // converts data into a format that effectively results in a convolution when combined with matrix multiplication
+    GGML_API struct ggml_tensor * ggml_im2col(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                   s0, // stride dimension 0
+            int                   s1, // stride dimension 1
+            int                   p0, // padding dimension 0
+            int                   p1, // padding dimension 1
+            int                   d0, // dilation dimension 0
+            int                   d1, // dilation dimension 1
+            bool                  is_2D,
+            enum ggml_type        dst_type);
+
+    GGML_API struct ggml_tensor * ggml_im2col_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,  // convolution kernel
+        struct ggml_tensor  * b,  // gradient of im2col output
+        int64_t             * ne, // shape of im2col input
+        int                   s0, // stride dimension 0
+        int                   s1, // stride dimension 1
+        int                   p0, // padding dimension 0
+        int                   p1, // padding dimension 1
+        int                   d0, // dilation dimension 0
+        int                   d1, // dilation dimension 1
+        bool                  is_2D);
+
+    GGML_API struct ggml_tensor * ggml_conv_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
+
+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                   s,  // stride
+            int                   d); // dilation
+
+    // depthwise
+    // TODO: this is very likely wrong for some cases! - needs more testing
+    GGML_API struct ggml_tensor * ggml_conv_1d_dw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
+
+    GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   d0); // dilation
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
+
+    GGML_API struct ggml_tensor * ggml_conv_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride dimension 0
+            int                   s1,  // stride dimension 1
+            int                   p0,  // padding dimension 0
+            int                   p1,  // padding dimension 1
+            int                   d0,  // dilation dimension 0
+            int                   d1); // dilation dimension 1
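+
+    // Illustrative shapes: for a kernel a with ne = {3, 3, IC, OC} and data b with
+    // ne = {W, H, IC, N}, ggml_conv_2d(ctx, a, b, 1, 1, 1, 1, 1, 1) (stride 1,
+    // padding 1, dilation 1) returns a tensor of shape {W, H, OC, N},
+    // i.e. a "same" convolution.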
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is 1
+    // padding is half
+    // example:
+    // a:      3    3    256  256
+    // b:     64   64    256    1
+    // res:   64   64    256    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // depthwise
+    GGML_API struct ggml_tensor * ggml_conv_2d_dw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                  s0,  // stride dimension 0
+            int                  s1,  // stride dimension 1
+            int                  p0,  // padding dimension 0
+            int                  p1,  // padding dimension 1
+            int                  d0,  // dilation dimension 0
+            int                  d1); // dilation dimension 1
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   stride);
+
+    enum ggml_op_pool {
+        GGML_OP_POOL_MAX,
+        GGML_OP_POOL_AVG,
+        GGML_OP_POOL_COUNT,
+    };
+
+    GGML_API struct ggml_tensor * ggml_pool_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0, // kernel size
+            int                   s0, // stride
+            int                   p0); // padding
+
+    // the result will have 2*p0 padding for the first dimension
+    // and 2*p1 padding for the second dimension
+    GGML_API struct ggml_tensor * ggml_pool_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0,
+            int                   k1,
+            int                   s0,
+            int                   s1,
+            float                 p0,
+            float                 p1);
+
+    GGML_API struct ggml_tensor * ggml_pool_2d_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * af, // "a"/input used in forward pass
+            enum ggml_op_pool     op,
+            int                   k0,
+            int                   k1,
+            int                   s0,
+            int                   s1,
+            float                 p0,
+            float                 p1);
+
+    // nearest interpolate
+    // multiplies ne0 and ne1 by scale factor
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   scale_factor);
+
+    // nearest interpolate to specified dimensions
+    // used in tortoise.cpp
+    GGML_API struct ggml_tensor * ggml_upscale_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   ne0,
+            int                   ne1,
+            int                   ne2,
+            int                   ne3);
+
+    // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+    GGML_API struct ggml_tensor * ggml_pad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
+    // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
+    GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1);
+
+    // unpad each dimension: [x, ..., x, y, ..., y] -> [x, ..., x]
+    GGML_API struct ggml_tensor * ggml_unpad(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                  p0,
+            int                  p1,
+            int                  p2,
+            int                  p3);
+
+    // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
+    // timesteps: [N,]
+    // return: [N, dim]
+    GGML_API struct ggml_tensor * ggml_timestep_embedding(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * timesteps,
+            int                   dim,
+            int                   max_period);
+
+    // sort rows
+    enum ggml_sort_order {
+        GGML_SORT_ORDER_ASC,
+        GGML_SORT_ORDER_DESC,
+    };
+
+    GGML_API struct ggml_tensor * ggml_argsort(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_sort_order  order);
+
+    GGML_API struct ggml_tensor * ggml_arange(
+            struct ggml_context * ctx,
+            float                 start,
+            float                 stop,
+            float                 step);
+
+    // top k elements per row
+    GGML_API struct ggml_tensor * ggml_top_k(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   k);
+
+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * mask,
+            float                 scale,
+            float                 max_bias,
+            float                 logit_softcap);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
+    // TODO: needs to be adapted to ggml_flash_attn_ext
+    GGML_API struct ggml_tensor * ggml_flash_attn_back(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * q,
+           struct ggml_tensor  * k,
+           struct ggml_tensor  * v,
+           struct ggml_tensor  * d,
+           bool                  masked);
+
+    GGML_API struct ggml_tensor * ggml_ssm_conv(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * sx,
+            struct ggml_tensor  * c);
+
+    GGML_API struct ggml_tensor * ggml_ssm_scan(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * s,
+            struct ggml_tensor  * x,
+            struct ggml_tensor  * dt,
+            struct ggml_tensor  * A,
+            struct ggml_tensor  * B,
+            struct ggml_tensor  * C);
+
+    // partition into non-overlapping windows with padding if needed
+    // example:
+    // a:   768   64   64    1
+    // w:    14
+    // res: 768   14   14    25
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_part(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w);
+
+    // reverse of ggml_win_part
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_win_unpart(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   w0,
+            int                   h0,
+            int                   w);
+
+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_unary_op    op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_unary_op    op);
+
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_get_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   qh,
+            int                   kh);
+
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_add_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
+
+    // custom operators
+
+    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
+    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+
+    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+                   ggml_unary_op_f32_t   fun),
+        "use ggml_map_custom1 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+                   ggml_unary_op_f32_t   fun),
+        "use ggml_map_custom1_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+                   ggml_binary_op_f32_t   fun),
+        "use ggml_map_custom2 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+                   ggml_binary_op_f32_t   fun),
+        "use ggml_map_custom2_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun),
+        "use ggml_map_custom1 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+                   ggml_custom1_op_f32_t   fun),
+        "use ggml_map_custom1_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun),
+        "use ggml_map_custom2 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+                   ggml_custom2_op_f32_t   fun),
+        "use ggml_map_custom2_inplace instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun),
+        "use ggml_map_custom3 instead");
+
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
+            struct ggml_context          * ctx,
+            struct ggml_tensor           * a,
+            struct ggml_tensor           * b,
+            struct ggml_tensor           * c,
+                   ggml_custom3_op_f32_t   fun),
+        "use ggml_map_custom3_inplace instead");
+
+    // custom operators v2
+
+    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
+    typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
+    typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
+
+#define GGML_N_TASKS_MAX (-1)
+    // n_tasks == GGML_N_TASKS_MAX means to use the maximum number of tasks
+
+    GGML_API struct ggml_tensor * ggml_map_custom1(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            ggml_custom1_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            ggml_custom1_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            ggml_custom2_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            ggml_custom2_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            struct ggml_tensor    * c,
+            ggml_custom3_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
+            struct ggml_context   * ctx,
+            struct ggml_tensor    * a,
+            struct ggml_tensor    * b,
+            struct ggml_tensor    * c,
+            ggml_custom3_op_t       fun,
+            int                     n_tasks,
+            void                  * userdata);
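+
+    // Usage sketch (illustrative; assumes contiguous F32 tensors, an existing
+    // context ctx and input tensor t, and a hypothetical callback my_relu):
+    // each of the nth tasks processes its own slice, selected via ith.
+    //
+    //   static void my_relu(struct ggml_tensor * dst, const struct ggml_tensor * a,
+    //                       int ith, int nth, void * userdata) {
+    //       const float * x = (const float *) a->data;
+    //       float       * y = (float *) dst->data;
+    //       const int64_t n  = ggml_nelements(dst);
+    //       const int64_t dr = (n + nth - 1) / nth;              // elements per task
+    //       const int64_t i0 = dr * ith;
+    //       const int64_t i1 = i0 + dr < n ? i0 + dr : n;
+    //       for (int64_t i = i0; i < i1; i++) {
+    //           y[i] = x[i] > 0.0f ? x[i] : 0.0f;
+    //       }
+    //   }
+    //
+    //   // schedule with as many tasks as there are threads:
+    //   struct ggml_tensor * out = ggml_map_custom1(ctx, t, my_relu, GGML_N_TASKS_MAX, NULL);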
+
+    // loss function
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // logits
+            struct ggml_tensor  * b); // labels
+
+    GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // logits
+            struct ggml_tensor  * b,  // labels
+            struct ggml_tensor  * c); // gradients of cross_entropy_loss result
+
+    // AdamW optimizer step
+    // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
+    // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
+    GGML_API struct ggml_tensor * ggml_opt_step_adamw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * grad,
+            struct ggml_tensor  * m,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * adamw_params); // parameters such as the learning rate
+
+    //
+    // automatic differentiation
+    //
+
+    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_backward_expand(
+        struct ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
+        struct ggml_context * ctx_compute, // context for gradient computation
+        struct ggml_cgraph  * cgraph,
+        bool                  accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+
+    // graph allocation in a context
+    GGML_API struct ggml_cgraph * ggml_new_graph       (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
+    GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
+    GGML_API struct ggml_cgraph * ggml_graph_dup       (struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API void                 ggml_graph_cpy       (struct ggml_cgraph * src, struct ggml_cgraph * dst);
+    GGML_API void                 ggml_graph_reset     (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
+    GGML_API void                 ggml_graph_clear     (struct ggml_cgraph * cgraph);
+
+    GGML_API int                   ggml_graph_size   (struct ggml_cgraph * cgraph);
+    GGML_API struct ggml_tensor *  ggml_graph_node   (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
+    GGML_API struct ggml_tensor ** ggml_graph_nodes  (struct ggml_cgraph * cgraph);
+    GGML_API int                   ggml_graph_n_nodes(struct ggml_cgraph * cgraph);
+
+    GGML_API void   ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+
+    GGML_API size_t ggml_graph_overhead(void);
+    GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);
+
+    GGML_API struct ggml_tensor * ggml_graph_get_tensor  (const struct ggml_cgraph * cgraph, const char * name);
+    GGML_API struct ggml_tensor * ggml_graph_get_grad    (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
+    GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
+
+    GGML_API void                 ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
+    GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
+
+    // print info and performance information for the graph
+    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
+
+    // dump the graph into a file using the dot format
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
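+
+    // Usage sketch (illustrative; assumes a context ctx and input tensors a, b
+    // built earlier, and ggml_add declared above in this header):
+    //
+    //   struct ggml_cgraph * gf = ggml_new_graph(ctx);
+    //   struct ggml_tensor * c  = ggml_add(ctx, a, b);
+    //   ggml_build_forward_expand(gf, c);
+    //   ggml_graph_print(gf);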
+
+    // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
+    typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
+
+    // Set callback for all future logging events.
+    // If this is not called, or NULL is supplied, everything is output on stderr.
+    GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);
+
+    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+
+    //
+    // quantization
+    //
+
+    // - ggml_quantize_init can be called multiple times with the same type
+    //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
+    //   automatically called by ggml_quantize_chunk for convenience
+    //
+    // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
+    //   call this at the end of the program to avoid memory leaks
+    //
+    // note: these are thread-safe
+    //
+    GGML_API void ggml_quantize_init(enum ggml_type type);
+    GGML_API void ggml_quantize_free(void);
+
+    // some quantization types cannot be used without an importance matrix
+    GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);
+
+    // calls ggml_quantize_init internally (i.e. can allocate memory)
+    GGML_API size_t ggml_quantize_chunk(
+            enum ggml_type   type,
+               const float * src,
+                      void * dst,
+                   int64_t   start,
+                   int64_t   nrows,
+                   int64_t   n_per_row,
+               const float * imatrix);
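+
+    // Usage sketch (illustrative; assumes an F32 source matrix src holding
+    // nrows x n_per_row values). The destination size comes from ggml_row_size();
+    // types for which ggml_quantize_requires_imatrix() returns true need a
+    // non-NULL imatrix.
+    //
+    //   size_t dst_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row) * nrows;
+    //   void * dst = malloc(dst_size);
+    //   size_t written = ggml_quantize_chunk(GGML_TYPE_Q4_0, src, dst,
+    //                                        /*start=*/0, nrows, n_per_row,
+    //                                        /*imatrix=*/NULL);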
+
+    //
+    // gguf
+    //
+
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
+    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
+
+    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
+    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
+
+    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
+    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
+    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
+    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
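+
+    // Usage sketch (illustrative; "model.gguf" is a placeholder path, error
+    // handling omitted): read metadata without allocating tensor data.
+    //
+    //   struct ggml_context * meta = NULL;
+    //   struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ &meta };
+    //   struct gguf_context * g = gguf_init_from_file("model.gguf", params);
+    //   int kid = gguf_find_key(g, "general.architecture");
+    //   if (kid >= 0 && gguf_get_kv_type(g, kid) == GGUF_TYPE_STRING) {
+    //       printf("arch: %s\n", gguf_get_val_str(g, kid));
+    //   }
+    //   for (int i = 0; i < gguf_get_n_tensors(g); i++) {
+    //       printf("tensor: %s\n", gguf_get_tensor_name(g, i));
+    //   }
+    //   gguf_free(g);
+    //   ggml_free(meta);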
+
+    // removes key if it exists
+    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+    // overrides existing values or adds a new one
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+
+    // manage tensor info
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+
+    // writing gguf files can be done in 2 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+    //   fwrite(f, ...);
+    //   void * data = malloc(gguf_get_meta_size(ctx));
+    //   gguf_get_meta_data(ctx, data);
+    //   fseek(f, 0, SEEK_SET);
+    //   fwrite(data, 1, gguf_get_meta_size(ctx), f);
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+#ifdef __cplusplus
+    // restrict not standard in C++
+#    if defined(__GNUC__)
+#        define GGML_RESTRICT __restrict__
+#    elif defined(__clang__)
+#        define GGML_RESTRICT __restrict
+#    elif defined(_MSC_VER)
+#        define GGML_RESTRICT __restrict
+#    else
+#        define GGML_RESTRICT
+#    endif
+#else
+#    define GGML_RESTRICT restrict
+#endif
+    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
+
+    struct ggml_type_traits {
+        const char             * type_name;
+        int64_t                  blck_size;
+        int64_t                  blck_size_interleave; // interleave elements in blocks
+        size_t                   type_size;
+        bool                     is_quantized;
+        ggml_to_float_t          to_float;
+        ggml_from_float_t        from_float_ref;
+    };
+
+    GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);
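+
+    // Usage sketch (illustrative; k is the number of elements in the row and is
+    // assumed to be a multiple of the type's block size):
+    //
+    //   const struct ggml_type_traits * tt = ggml_get_type_traits(GGML_TYPE_Q4_0);
+    //   float * out = malloc(k * sizeof(float));
+    //   tt->to_float(quantized_row, out, k);   // dequantize one row back to F32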
+
+    // ggml threadpool
+    // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend
+    // the goal should be to create an API that other backends can use, and then move everything to the ggml base
+
+    // scheduling priorities
+    enum ggml_sched_priority {
+        GGML_SCHED_PRIO_NORMAL,
+        GGML_SCHED_PRIO_MEDIUM,
+        GGML_SCHED_PRIO_HIGH,
+        GGML_SCHED_PRIO_REALTIME
+    };
+
+    // threadpool params
+    // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults
+    struct ggml_threadpool_params {
+        bool                cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
+        int                 n_threads;                   // number of threads
+        enum ggml_sched_priority prio;                   // thread priority
+        uint32_t            poll;                        // polling level (0 - no polling, 100 - aggressive polling)
+        bool                strict_cpu;                  // strict cpu placement
+        bool                paused;                      // start in paused state
+    };
+
+    struct ggml_threadpool;     // forward declaration, see ggml.c
+
+    typedef struct ggml_threadpool * ggml_threadpool_t;
+
+    GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
+    GGML_API void                          ggml_threadpool_params_init   (struct ggml_threadpool_params * p, int n_threads);
+    GGML_API bool                          ggml_threadpool_params_match  (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
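+
+    // Usage sketch (illustrative): populate defaults for 8 threads, then adjust.
+    //
+    //   struct ggml_threadpool_params tpp = ggml_threadpool_params_default(8);
+    //   tpp.prio = GGML_SCHED_PRIO_HIGH;   // raise scheduling priority
+    //   tpp.poll = 50;                     // moderate polling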
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/llama/grammar/grammar_test.go b/llama/grammar/grammar_test.go
new file mode 100644
index 000000000..373652ab4
--- /dev/null
+++ b/llama/grammar/grammar_test.go
@@ -0,0 +1,107 @@
+package grammar
+
+import (
+	"bufio"
+	"bytes"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/llama"
+)
+
+// https://github.com/ollama/ollama/issues/7978
+const issue7978JSONSchema = `{
+  "type": "object",
+  "properties": {
+    "steps": {
+      "type": "array",
+      "items": {
+        "type": "object",
+        "properties": {
+          "explanation": { "type": "string" },
+          "output": { "type": "string" },
+          "nested": {
+            "type": "object",
+            "properties": {
+              "deep": { "type": "string" }
+            }
+          }
+        },
+        "required": ["explanation", "output"],
+        "additionalProperties": false
+      }
+    },
+    "final_answer": { "type": "string" },
+    "01_numbered_key": { "type": "string" },
+    "numbers": {
+      "type": "array",
+      "items": { "type": "number" }
+    },
+    "booleans": {
+      "type": "array", 
+      "items": { "type": "boolean" }
+    },
+    "mixed": {
+      "type": "array",
+      "items": {
+        "oneOf": [
+          { "type": "string" },
+          { "type": "number" },
+          { "type": "boolean" }
+        ]
+      }
+    }
+  },
+  "required": ["steps", "final_answer"],
+  "additionalProperties": false
+}`
+
+func TestIssue7978(t *testing.T) {
+	g := llama.SchemaToGrammar([]byte(issue7978JSONSchema))
+	if g == nil {
+		t.Fatal("failed to convert JSON schema to grammar")
+	}
+
+	t.Logf("grammar:\n%s", g)
+	t.Log()
+
+	var got string
+	s := bufio.NewScanner(bytes.NewReader(g))
+	for s.Scan() {
+		line := strings.TrimSpace(s.Text())
+		step, _, _ := strings.Cut(line, " ::= ")
+		step = strings.TrimSpace(step)
+		if step == "root" {
+			got = line
+		}
+	}
+
+	want := `root ::= "{" space steps-kv "," space final-answer-kv ( "," space ( 01-numbered-key-kv 01-numbered-key-rest | numbers-kv numbers-rest | booleans-kv booleans-rest | mixed-kv ) )? "}" space`
+	if got != want {
+		t.Errorf("root =\n%q\nwant:\n%q", got, want)
+	}
+}
+
+func TestSchemaToGrammer(t *testing.T) {
+	cases := []struct {
+		schema string
+		prefix []byte // nil means the returned grammar is expected to be nil
+	}{
+		{`invalid`, nil},
+
+		// Simple heuristic/smoke test
+		{`{"type":"object"}`, []byte("root ::= object")},
+	}
+
+	for _, c := range cases {
+		t.Run("x", func(t *testing.T) {
+			g := llama.SchemaToGrammar([]byte(c.schema))
+			if c.prefix == nil && g != nil {
+				t.Fatalf("grammar = %v, want nil", g)
+			}
+			if !bytes.HasPrefix(g, c.prefix) {
+				t.Errorf("grammar = %q, want %q", g, c.prefix)
+			}
+		})
+	}
+}
diff --git a/llama/json-schema-to-grammar.cpp b/llama/json-schema-to-grammar.cpp
new file mode 100644
index 000000000..cc870f9f3
--- /dev/null
+++ b/llama/json-schema-to-grammar.cpp
@@ -0,0 +1,1071 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "json-schema-to-grammar.h"
+#include <algorithm>
+#include <fstream>
+#include <map>
+#include <regex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+template <typename Iterator>
+static std::string join(Iterator begin, Iterator end, const std::string & separator);
+
+static std::string repeat(const std::string & str, size_t n);
+
+static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
+    auto has_max = max_items != std::numeric_limits<int>::max();
+
+    if (min_items == 0 && max_items == 1) {
+        return item_rule + "?";
+    }
+
+    if (separator_rule.empty()) {
+        if (min_items == 1 && !has_max) {
+            return item_rule + "+";
+        } else if (min_items == 0 && !has_max) {
+            return item_rule + "*";
+        } else {
+            return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}";
+        }
+    }
+
+    auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items);
+    if (min_items == 0) {
+        result = "(" + result + ")?";
+    }
+    return result;
+}
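+
+// Illustrative examples of the GBNF produced by build_repetition (derived from
+// the logic above):
+//
+//   build_repetition("item", 0, INT_MAX)           -> item*
+//   build_repetition("item", 1, INT_MAX)           -> item+
+//   build_repetition("item", 1, 3, "\",\" space")  -> item ("," space item){0,2}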
+
+/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */
+class string_view {
+    const std::string & _str;
+    const size_t _start;
+    const size_t _end;
+public:
+    string_view(const std::string & str, size_t start = 0, size_t end  = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {}
+
+    size_t size() const {
+        return _end - _start;
+    }
+
+    size_t length() const {
+        return size();
+    }
+
+    operator std::string() const {
+        return str();
+    }
+
+    std::string str() const {
+        return _str.substr(_start, _end - _start);
+    }
+
+    string_view substr(size_t pos, size_t len = std::string::npos) const {
+        return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len);
+    }
+
+    char operator[](size_t pos) const {
+        auto index = _start + pos;
+        if (index >= _end) {
+            throw std::out_of_range("string_view index out of range");
+        }
+        return _str[_start + pos];
+    }
+
+    bool operator==(const string_view & other) const {
+        std::string this_str = *this;
+        std::string other_str = other;
+        return this_str == other_str;
+    }
+};
+
+static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
+    auto has_min = min_value != std::numeric_limits<int>::min();
+    auto has_max = max_value != std::numeric_limits<int>::max();
+
+    auto digit_range = [&](char from, char to) {
+        out << "[";
+        if (from == to) {
+            out << from;
+        } else {
+            out << from << "-" << to;
+        }
+        out << "]";
+    };
+    auto more_digits = [&](int min_digits, int max_digits) {
+        out << "[0-9]";
+        if (min_digits == max_digits && min_digits == 1) {
+            return;
+        }
+        out << "{";
+        out << min_digits;
+        if (max_digits != min_digits) {
+            out << ",";
+            if (max_digits != std::numeric_limits<int>::max()) {
+                out << max_digits;
+            }
+        }
+        out << "}";
+    };
+    std::function<void(const string_view &, const string_view &)> uniform_range =
+        [&](const string_view & from, const string_view & to) {
+            size_t i = 0;
+            while (i < from.length() && i < to.length() && from[i] == to[i]) {
+                i++;
+            }
+            if (i > 0) {
+                out << "\"" << from.substr(0, i).str() << "\"";
+            }
+            if (i < from.length() && i < to.length()) {
+                if (i > 0) {
+                    out << " ";
+                }
+                auto sub_len = from.length() - i - 1;
+                if (sub_len > 0) {
+                    auto from_sub = from.substr(i + 1);
+                    auto to_sub = to.substr(i + 1);
+                    auto sub_zeros = repeat("0", sub_len);
+                    auto sub_nines = repeat("9", sub_len);
+
+                    auto to_reached = false;
+                    out << "(";
+                    if (from_sub == sub_zeros) {
+                        digit_range(from[i], to[i] - 1);
+                        out << " ";
+                        more_digits(sub_len, sub_len);
+                    } else {
+                        out << "[" << from[i] << "] ";
+                        out << "(";
+                        uniform_range(from_sub, sub_nines);
+                        out << ")";
+                        if (from[i] < to[i] - 1) {
+                            out << " | ";
+                            if (to_sub == sub_nines) {
+                                digit_range(from[i] + 1, to[i]);
+                                to_reached = true;
+                            } else {
+                                digit_range(from[i] + 1, to[i] - 1);
+                            }
+                            out << " ";
+                            more_digits(sub_len, sub_len);
+                        }
+                    }
+                    if (!to_reached) {
+                        out << " | ";
+                        digit_range(to[i], to[i]);
+                        out << " ";
+                        uniform_range(sub_zeros, to_sub);
+                    }
+                    out << ")";
+                } else {
+                    out << "[" << from[i] << "-" << to[i] << "]";
+                }
+            }
+        };
+
+    if (has_min && has_max) {
+        if (min_value < 0 && max_value < 0) {
+            out << "\"-\" (";
+            _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
+            out << ")";
+            return;
+        }
+
+        if (min_value < 0) {
+            out << "\"-\" (";
+            _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
+            out << ") | ";
+            min_value = 0;
+        }
+
+        auto min_s = std::to_string(min_value);
+        auto max_s = std::to_string(max_value);
+        auto min_digits = min_s.length();
+        auto max_digits = max_s.length();
+
+        for (auto digits = min_digits; digits < max_digits; digits++) {
+            uniform_range(min_s, repeat("9", digits));
+            min_s = "1" + repeat("0", digits);
+            out << " | ";
+        }
+        uniform_range(min_s, max_s);
+        return;
+    }
+
+    auto less_decimals = std::max(decimals_left - 1, 1);
+
+    if (has_min) {
+        if (min_value < 0) {
+            out << "\"-\" (";
+            _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
+            out << ") | [0] | [1-9] ";
+            more_digits(0, decimals_left - 1);
+        } else if (min_value == 0) {
+            if (top_level) {
+                out << "[0] | [1-9] ";
+                more_digits(0, less_decimals);
+            } else {
+                more_digits(1, decimals_left);
+            }
+        } else if (min_value <= 9) {
+            char c = '0' + min_value;
+            auto range_start = top_level ? '1' : '0';
+            if (c > range_start) {
+                digit_range(range_start, c - 1);
+                out << " ";
+                more_digits(1, less_decimals);
+                out << " | ";
+            }
+            digit_range(c, '9');
+            out << " ";
+            more_digits(0, less_decimals);
+        } else {
+            auto min_s = std::to_string(min_value);
+            auto len = min_s.length();
+            auto c = min_s[0];
+
+            if (c > '1') {
+                digit_range(top_level ? '1' : '0', c - 1);
+                out << " ";
+                more_digits(len, less_decimals);
+                out << " | ";
+            }
+            digit_range(c, c);
+            out << " (";
+            _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
+            out << ")";
+            if (c < '9') {
+                out << " | ";
+                digit_range(c + 1, '9');
+                out << " ";
+                more_digits(len - 1, less_decimals);
+            }
+        }
+        return;
+    }
+
+    if (has_max) {
+        if (max_value >= 0) {
+            if (top_level) {
+                out << "\"-\" [1-9] ";
+                more_digits(0, less_decimals);
+                out << " | ";
+            }
+            _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
+        } else {
+            out << "\"-\" (";
+            _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
+            out << ")";
+        }
+        return;
+    }
+
+    throw std::runtime_error("At least one of min_value or max_value must be set");
+}
+
+const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
+
+struct BuiltinRule {
+    std::string content;
+    std::vector<std::string> deps;
+};
+
+std::unordered_map<std::string, BuiltinRule> PRIMITIVE_RULES = {
+    {"boolean", {"(\"true\" | \"false\") space", {}}},
+    {"decimal-part", {"[0-9]{1,16}", {}}},
+    {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}},
+    {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}},
+    {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}},
+    {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}},
+    {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}},
+    {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}},
+    {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}},
+    {"char",   {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}},
+    {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}},
+    {"null", {"\"null\" space", {}}},
+};
+
+std::unordered_map<std::string, BuiltinRule> STRING_FORMAT_RULES = {
+    {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}},
+    {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}},
+    {"date-time", {"date \"T\" time", {"date", "time"}}},
+    {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}},
+    {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}},
+    {"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}}
+};
+
+static bool is_reserved_name(const std::string & name) {
+    static std::unordered_set<std::string> RESERVED_NAMES;
+    if (RESERVED_NAMES.empty()) {
+        RESERVED_NAMES.insert("root");
+        for (const auto &p : PRIMITIVE_RULES) RESERVED_NAMES.insert(p.first);
+        for (const auto &p : STRING_FORMAT_RULES) RESERVED_NAMES.insert(p.first);
+    }
+    return RESERVED_NAMES.find(name) != RESERVED_NAMES.end();
+}
+
+std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
+std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
+std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
+std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
+    {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
+};
+
+std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
+std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
+
+template <typename Iterator>
+std::string join(Iterator begin, Iterator end, const std::string & separator) {
+    std::ostringstream result;
+    if (begin != end) {
+        result << *begin;
+        for (Iterator it = begin + 1; it != end; ++it) {
+            result << separator << *it;
+        }
+    }
+    return result.str();
+}
+
+static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
+    std::vector<std::string> tokens;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        tokens.push_back(str.substr(start, end - start));
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    tokens.push_back(str.substr(start));
+
+    return tokens;
+}
+
+static std::string repeat(const std::string & str, size_t n) {
+    if (n == 0) {
+        return "";
+    }
+
+    std::string result;
+    result.reserve(str.length() * n);
+
+    for (size_t i = 0; i < n; ++i) {
+        result += str;
+    }
+
+    return result;
+}
+
+static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
+    std::smatch match;
+    std::string result;
+
+    std::string::const_iterator searchStart(input.cbegin());
+    std::string::const_iterator searchEnd(input.cend());
+
+    while (std::regex_search(searchStart, searchEnd, match, regex)) {
+        result.append(searchStart, searchStart + match.position());
+        result.append(replacement(match));
+        searchStart = match.suffix().first;
+    }
+
+    result.append(searchStart, searchEnd);
+
+    return result;
+}
+
+static std::string format_literal(const std::string & literal) {
+    std::string escaped = replacePattern(literal, GRAMMAR_LITERAL_ESCAPE_RE, [&](const std::smatch & match) {
+        char c = match.str()[0];
+        return GRAMMAR_LITERAL_ESCAPES.at(c);
+    });
+    return "\"" + escaped + "\"";
+}
+
+class SchemaConverter {
+private:
+    std::function<json(const std::string &)> _fetch_json;
+    bool _dotall;
+    std::unordered_map<std::string, std::string> _rules;
+    std::unordered_map<std::string, json> _refs;
+    std::unordered_set<std::string> _refs_being_resolved;
+    std::vector<std::string> _errors;
+    std::vector<std::string> _warnings;
+
+    std::string _add_rule(const std::string & name, const std::string & rule) {
+        std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
+        if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) {
+            _rules[esc_name] = rule;
+            return esc_name;
+        } else {
+            int i = 0;
+            while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) {
+                i++;
+            }
+            std::string key = esc_name + std::to_string(i);
+            _rules[key] = rule;
+            return key;
+        }
+    }
+
+    std::string _generate_union_rule(const std::string & name, const std::vector<json> & alt_schemas) {
+        std::vector<std::string> rules;
+        for (size_t i = 0; i < alt_schemas.size(); i++) {
+            rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
+        }
+        return join(rules.begin(), rules.end(), " | ");
+    }
+
+    std::string _visit_pattern(const std::string & pattern, const std::string & name) {
+        if (!(pattern.front() == '^' && pattern.back() == '$')) {
+            _errors.push_back("Pattern must start with '^' and end with '$'");
+            return "";
+        }
+        std::string sub_pattern = pattern.substr(1, pattern.length() - 2);
+        std::unordered_map<std::string, std::string> sub_rule_ids;
+
+        size_t i = 0;
+        size_t length = sub_pattern.length();
+
+        using literal_or_rule = std::pair<std::string, bool>;
+        auto to_rule = [&](const literal_or_rule & ls) {
+            auto is_literal = ls.second;
+            auto s = ls.first;
+            return is_literal ? "\"" + s + "\"" : s;
+        };
+        std::function<literal_or_rule()> transform = [&]() -> literal_or_rule {
+            size_t start = i;
+            std::vector<literal_or_rule> seq;
+
+            auto get_dot = [&]() {
+                std::string rule;
+                if (_dotall) {
+                    rule = "[\\U00000000-\\U0010FFFF]";
+                } else {
+                    rule = "[^\\x0A\\x0D]";
+                }
+                return _add_rule("dot", rule);
+            };
+
+            // Joins the sequence, merging consecutive literals together.
+            auto join_seq = [&]() {
+                std::vector<literal_or_rule> ret;
+
+                std::string literal;
+                auto flush_literal = [&]() {
+                    if (literal.empty()) {
+                        return false;
+                    }
+                    ret.emplace_back(literal, true);
+                    literal.clear();
+                    return true;
+                };
+
+                for (const auto & item : seq) {
+                    auto is_literal = item.second;
+                    if (is_literal) {
+                        literal += item.first;
+                    } else {
+                        flush_literal();
+                        ret.push_back(item);
+                    }
+                }
+                flush_literal();
+
+                std::vector<std::string> results;
+                for (const auto & item : ret) {
+                    results.push_back(to_rule(item));
+                }
+                return std::make_pair(join(results.begin(), results.end(), " "), false);
+            };
+
+            while (i < length) {
+                char c = sub_pattern[i];
+                if (c == '.') {
+                    seq.emplace_back(get_dot(), false);
+                    i++;
+                } else if (c == '(') {
+                    i++;
+                    if (i < length) {
+                        if (sub_pattern[i] == '?') {
+                            _warnings.push_back("Unsupported pattern syntax");
+                        }
+                    }
+                    seq.emplace_back("(" + to_rule(transform()) + ")", false);
+                } else if (c == ')') {
+                    i++;
+                    if (start > 0 && sub_pattern[start - 1] != '(') {
+                        _errors.push_back("Unbalanced parentheses");
+                    }
+                    return join_seq();
+                } else if (c == '[') {
+                    std::string square_brackets = std::string(1, c);
+                    i++;
+                    while (i < length && sub_pattern[i] != ']') {
+                        if (sub_pattern[i] == '\\') {
+                            square_brackets += sub_pattern.substr(i, 2);
+                            i += 2;
+                        } else {
+                            square_brackets += sub_pattern[i];
+                            i++;
+                        }
+                    }
+                    if (i >= length) {
+                        _errors.push_back("Unbalanced square brackets");
+                    }
+                    square_brackets += ']';
+                    i++;
+                    seq.emplace_back(square_brackets, false);
+                } else if (c == '|') {
+                    seq.emplace_back("|", false);
+                    i++;
+                } else if (c == '*' || c == '+' || c == '?') {
+                    seq.back() = std::make_pair(to_rule(seq.back()) + c, false);
+                    i++;
+                } else if (c == '{') {
+                    std::string curly_brackets = std::string(1, c);
+                    i++;
+                    while (i < length && sub_pattern[i] != '}') {
+                        curly_brackets += sub_pattern[i];
+                        i++;
+                    }
+                    if (i >= length) {
+                        _errors.push_back("Unbalanced curly brackets");
+                    }
+                    curly_brackets += '}';
+                    i++;
+                    auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
+                    int min_times = 0;
+                    int max_times = std::numeric_limits<int>::max();
+                    try {
+                        if (nums.size() == 1) {
+                            min_times = max_times = std::stoi(nums[0]);
+                        } else if (nums.size() != 2) {
+                            _errors.push_back("Wrong number of values in curly brackets");
+                        } else {
+                            if (!nums[0].empty()) {
+                                min_times = std::stoi(nums[0]);
+                            }
+                            if (!nums[1].empty()) {
+                                max_times = std::stoi(nums[1]);
+                            }
+                        }
+                    } catch (const std::invalid_argument & e) {
+                        _errors.push_back("Invalid number in curly brackets");
+                        return std::make_pair("", false);
+                    }
+                    auto &last = seq.back();
+                    auto &sub = last.first;
+                    auto sub_is_literal = last.second;
+
+                    if (!sub_is_literal) {
+                        std::string & sub_id = sub_rule_ids[sub];
+                        if (sub_id.empty()) {
+                            sub_id = _add_rule(name + "-" + std::to_string(sub_rule_ids.size()), sub);
+                        }
+                        sub = sub_id;
+                    }
+                    seq.back().first = build_repetition(
+                        sub_is_literal ? "\"" + sub + "\"" : sub,
+                        min_times,
+                        max_times,
+                        ""
+                    );
+                    seq.back().second = false;
+                } else {
+                    std::string literal;
+                    auto is_non_literal = [&](char c) {
+                        return NON_LITERAL_SET.find(c) != NON_LITERAL_SET.end();
+                    };
+                    while (i < length) {
+                        if (sub_pattern[i] == '\\' && i < length - 1) {
+                            char next = sub_pattern[i + 1];
+                            if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.find(next) != ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.end()) {
+                                i++;
+                                literal += sub_pattern[i];
+                                i++;
+                            } else {
+                                literal += sub_pattern.substr(i, 2);
+                                i += 2;
+                            }
+                        } else if (sub_pattern[i] == '"') {
+                            literal += "\\\"";
+                            i++;
+                        } else if (!is_non_literal(sub_pattern[i]) &&
+                                (i == length - 1 || literal.empty() || sub_pattern[i + 1] == '.' || !is_non_literal(sub_pattern[i + 1]))) {
+                            literal += sub_pattern[i];
+                            i++;
+                        } else {
+                            break;
+                        }
+                    }
+                    if (!literal.empty()) {
+                        seq.emplace_back(literal, true);
+                    }
+                }
+            }
+            return join_seq();
+        };
+        return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space");
+    }
+
+    /*
+        Returns a rule that matches a JSON string that is none of the provided strings
+
+        not_strings({"a"})
+            -> ["] ( [a] char+ | [^"a] char* )? ["] space
+        not_strings({"and", "also"})
+            -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
+    */
+    std::string _not_strings(const std::vector<std::string> & strings) {
+
+        struct TrieNode {
+            std::map<char, TrieNode> children;
+            bool is_end_of_string;
+
+            TrieNode() : is_end_of_string(false) {}
+
+            void insert(const std::string & string) {
+                auto node = this;
+                for (char c : string) {
+                    node = &node->children[c];
+                }
+                node->is_end_of_string = true;
+            }
+        };
+
+        TrieNode trie;
+        for (const auto & s : strings) {
+            trie.insert(s);
+        }
+
+        std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
+        std::ostringstream out;
+        out << "[\"] ( ";
+        std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) {
+            std::ostringstream rejects;
+            auto first = true;
+            for (const auto & kv : node.children) {
+                rejects << kv.first;
+                if (first) {
+                    first = false;
+                } else {
+                    out << " | ";
+                }
+                out << "[" << kv.first << "]";
+                if (!kv.second.children.empty()) {
+                    out << " (";
+                    visit(kv.second);
+                    out << ")";
+                } else if (kv.second.is_end_of_string) {
+                    out << " " << char_rule << "+";
+                }
+            }
+            if (!node.children.empty()) {
+                if (!first) {
+                    out << " | ";
+                }
+                out << "[^\"" << rejects.str() << "] " << char_rule << "*";
+            }
+        };
+        visit(trie);
+
+        out << " )";
+        if (!trie.is_end_of_string) {
+            out << "?";
+        }
+        out << " [\"] space";
+        return out.str();
+    }
+
+    std::string _resolve_ref(const std::string & ref) {
+        std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
+        if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
+            _refs_being_resolved.insert(ref);
+            json resolved = _refs[ref];
+            ref_name = visit(resolved, ref_name);
+            _refs_being_resolved.erase(ref);
+        }
+        return ref_name;
+    }
+
+    std::string _build_object_rule(
+        const std::vector<std::pair<std::string, json>> & properties,
+        const std::unordered_set<std::string> & required,
+        const std::string & name,
+        const json & additional_properties)
+    {
+        std::vector<std::string> required_props;
+        std::vector<std::string> optional_props;
+        std::unordered_map<std::string, std::string> prop_kv_rule_names;
+        std::vector<std::string> prop_names;
+        for (const auto & kv : properties) {
+            const auto &prop_name = kv.first;
+            const auto &prop_schema = kv.second;
+
+            std::string prop_rule_name = visit(prop_schema, name + (name.empty() ? "" : "-") + prop_name);
+            prop_kv_rule_names[prop_name] = _add_rule(
+                name + (name.empty() ? "" : "-") + prop_name + "-kv",
+                format_literal(json(prop_name).dump()) + " space \":\" space " + prop_rule_name
+            );
+            if (required.find(prop_name) != required.end()) {
+                required_props.push_back(prop_name);
+            } else {
+                optional_props.push_back(prop_name);
+            }
+            prop_names.push_back(prop_name);
+        }
+        if ((additional_properties.is_boolean() && additional_properties.get<bool>()) || additional_properties.is_object()) {
+            std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
+            std::string value_rule =
+                additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
+                : _add_primitive("value", PRIMITIVE_RULES.at("value"));
+
+            auto key_rule =
+                prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
+                : _add_rule(sub_name + "-k", _not_strings(prop_names));
+            std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
+            prop_kv_rule_names["*"] = kv_rule;
+            optional_props.push_back("*");
+        }
+
+        std::string rule = "\"{\" space ";
+        for (size_t i = 0; i < required_props.size(); i++) {
+            if (i > 0) {
+                rule += " \",\" space ";
+            }
+            rule += prop_kv_rule_names[required_props[i]];
+        }
+
+        if (!optional_props.empty()) {
+            rule += " (";
+            if (!required_props.empty()) {
+                rule += " \",\" space ( ";
+            }
+
+            std::function<std::string(const std::vector<std::string> &, bool)> get_recursive_refs = [&](const std::vector<std::string> & ks, bool first_is_optional) {
+                std::string res;
+                if (ks.empty()) {
+                    return res;
+                }
+                std::string k = ks[0];
+                std::string kv_rule_name = prop_kv_rule_names[k];
+                std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
+                if (first_is_optional) {
+                    res = comma_ref + (k == "*" ? "*" : "?");
+                } else {
+                    res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
+                }
+                if (ks.size() > 1) {
+                    res += " " + _add_rule(
+                        name + (name.empty() ? "" : "-") + k + "-rest",
+                        get_recursive_refs(std::vector<std::string>(ks.begin() + 1, ks.end()), true)
+                    );
+                }
+                return res;
+            };
+
+            for (size_t i = 0; i < optional_props.size(); i++) {
+                if (i > 0) {
+                    rule += " | ";
+                }
+                rule += get_recursive_refs(std::vector<std::string>(optional_props.begin() + i, optional_props.end()), false);
+            }
+            if (!required_props.empty()) {
+                rule += " )";
+            }
+            rule += " )?";
+        }
+
+        rule += " \"}\" space";
+
+        return rule;
+    }
+
+    std::string _add_primitive(const std::string & name, const BuiltinRule & rule) {
+        auto n = _add_rule(name, rule.content);
+        for (const auto & dep : rule.deps) {
+            BuiltinRule dep_rule;
+            auto it = PRIMITIVE_RULES.find(dep);
+            if (it == PRIMITIVE_RULES.end()) {
+                it = STRING_FORMAT_RULES.find(dep);
+                if (it == STRING_FORMAT_RULES.end()) {
+                    _errors.push_back("Rule " + dep + " not known");
+                    continue;
+                }
+            }
+            if (_rules.find(dep) == _rules.end()) {
+                _add_primitive(dep, it->second);
+            }
+        }
+        return n;
+    }
+
+public:
+    SchemaConverter(
+        const std::function<json(const std::string &)> & fetch_json,
+        bool dotall)
+          : _fetch_json(fetch_json), _dotall(dotall)
+    {
+        _rules["space"] = SPACE_RULE;
+    }
+
+    void resolve_refs(json & schema, const std::string & url) {
+        /*
+        * Resolves all $ref fields in the given schema, fetching any remote schemas,
+        * replacing each $ref with an absolute reference URL, and populates _refs with the
+        * respective referenced (sub)schema dictionaries.
+        */
+        std::function<void(json &)> visit_refs = [&](json & n) {
+            if (n.is_array()) {
+                for (auto & x : n) {
+                    visit_refs(x);
+                }
+            } else if (n.is_object()) {
+                if (n.contains("$ref")) {
+                    std::string ref = n["$ref"];
+                    if (_refs.find(ref) == _refs.end()) {
+                        json target;
+                        if (ref.find("https://") == 0) {
+                            std::string base_url = ref.substr(0, ref.find('#'));
+                            auto it = _refs.find(base_url);
+                            if (it != _refs.end()) {
+                                target = it->second;
+                            } else {
+                                // Fetch the referenced schema and resolve its refs
+                                auto referenced = _fetch_json(ref);
+                                resolve_refs(referenced, base_url);
+                                _refs[base_url] = referenced;
+                            }
+                            if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
+                                return;
+                            }
+                        } else if (ref.find("#/") == 0) {
+                            target = schema;
+                            n["$ref"] = url + ref;
+                            ref = url + ref;
+                        } else {
+                            _errors.push_back("Unsupported ref: " + ref);
+                            return;
+                        }
+                        std::string pointer = ref.substr(ref.find('#') + 1);
+                        std::vector<std::string> tokens = split(pointer, "/");
+                        for (size_t i = 1; i < tokens.size(); ++i) {
+                            std::string sel = tokens[i];
+                            if (target.is_null() || !target.contains(sel)) {
+                                _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
+                                return;
+                            }
+                            target = target[sel];
+                        }
+                        _refs[ref] = target;
+                    }
+                } else {
+                    for (auto & kv : n.items()) {
+                        visit_refs(kv.value());
+                    }
+                }
+            }
+        };
+
+        visit_refs(schema);
+    }
+
+    std::string _generate_constant_rule(const json & value) {
+        return format_literal(value.dump());
+    }
+
+    std::string visit(const json & schema, const std::string & name) {
+        json schema_type = schema.contains("type") ? schema["type"] : json();
+        std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
+        std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
+
+        if (schema.contains("$ref")) {
+            return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
+        } else if (schema.contains("oneOf") || schema.contains("anyOf")) {
+            std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
+            return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
+        } else if (schema_type.is_array()) {
+            std::vector<json> schema_types;
+            for (const auto & t : schema_type) {
+                json schema_copy(schema);
+                schema_copy["type"] = t;
+                schema_types.push_back(schema_copy);
+            }
+            return _add_rule(rule_name, _generate_union_rule(name, schema_types));
+        } else if (schema.contains("const")) {
+            return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
+        } else if (schema.contains("enum")) {
+            std::vector<std::string> enum_values;
+            for (const auto & v : schema["enum"]) {
+                enum_values.push_back(_generate_constant_rule(v));
+            }
+            return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
+        } else if ((schema_type.is_null() || schema_type == "object")
+                && (schema.contains("properties") ||
+                    (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
+            std::unordered_set<std::string> required;
+            if (schema.contains("required") && schema["required"].is_array()) {
+                for (const auto & item : schema["required"]) {
+                    if (item.is_string()) {
+                        required.insert(item.get<std::string>());
+                    }
+                }
+            }
+            std::vector<std::pair<std::string, json>> properties;
+            if (schema.contains("properties")) {
+                for (const auto & prop : schema["properties"].items()) {
+                    properties.emplace_back(prop.key(), prop.value());
+                }
+            }
+            return _add_rule(rule_name,
+                _build_object_rule(
+                    properties, required, name,
+                    schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
+        } else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
+            std::unordered_set<std::string> required;
+            std::vector<std::pair<std::string, json>> properties;
+            std::string hybrid_name = name;
+            std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
+                if (comp_schema.contains("$ref")) {
+                    add_component(_refs[comp_schema["$ref"]], is_required);
+                } else if (comp_schema.contains("properties")) {
+                    for (const auto & prop : comp_schema["properties"].items()) {
+                        properties.emplace_back(prop.key(), prop.value());
+                        if (is_required) {
+                            required.insert(prop.key());
+                        }
+                    }
+                } else {
+                  // todo warning
+                }
+            };
+            for (auto & t : schema["allOf"]) {
+                if (t.contains("anyOf")) {
+                    for (auto & tt : t["anyOf"]) {
+                        add_component(tt, false);
+                    }
+                } else {
+                    add_component(t, true);
+                }
+            }
+            return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
+        } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
+            json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
+            if (items.is_array()) {
+                std::string rule = "\"[\" space ";
+                for (size_t i = 0; i < items.size(); i++) {
+                    if (i > 0) {
+                        rule += " \",\" space ";
+                    }
+                    rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i));
+                }
+                rule += " \"]\" space";
+                return _add_rule(rule_name, rule);
+            } else {
+                std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item");
+                int min_items = schema.contains("minItems") ? schema["minItems"].get<int>() : 0;
+                json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json();
+                int max_items = max_items_json.is_number_integer() ? max_items_json.get<int>() : std::numeric_limits<int>::max();
+
+                return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space");
+            }
+        } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) {
+            return _visit_pattern(schema["pattern"], rule_name);
+        } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) {
+            return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid"));
+        } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) {
+            auto prim_name = schema_format + "-string";
+            return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name)));
+        } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) {
+            std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
+            int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
+            int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
+            return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
+        } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
+            int min_value = std::numeric_limits<int>::min();
+            int max_value = std::numeric_limits<int>::max();
+            if (schema.contains("minimum")) {
+                min_value = schema["minimum"].get<int>();
+            } else if (schema.contains("exclusiveMinimum")) {
+                min_value = schema["exclusiveMinimum"].get<int>() + 1;
+            }
+            if (schema.contains("maximum")) {
+                max_value = schema["maximum"].get<int>();
+            } else if (schema.contains("exclusiveMaximum")) {
+                max_value = schema["exclusiveMaximum"].get<int>() - 1;
+            }
+            std::stringstream out;
+            out << "(";
+            _build_min_max_int(min_value, max_value, out);
+            out << ") space";
+            return _add_rule(rule_name, out.str());
+        } else if (schema.empty() || schema_type == "object") {
+            return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
+        } else {
+            if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
+                _errors.push_back("Unrecognized schema: " + schema.dump());
+                return "";
+            }
+            // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
+            return _add_primitive(rule_name == "root" ? "root" : schema_type.get<std::string>(), PRIMITIVE_RULES.at(schema_type.get<std::string>()));
+        }
+    }
+
+    void check_errors() {
+        if (!_errors.empty()) {
+            throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
+        }
+        if (!_warnings.empty()) {
+            fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
+        }
+    }
+
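+    // Render every accumulated rule as one "name ::= body" GBNF line.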
+    std::string format_grammar() {
+        std::stringstream ss;
+        for (const auto & kv : _rules) {
+            ss << kv.first << " ::= " << kv.second << std::endl;
+        }
+        return ss.str();
+    }
+};
+
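+// Entry point: resolve $ref pointers in a copy of the schema, walk it to build
+// grammar rules, throw if any hard errors were recorded, then emit the grammar.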
+std::string json_schema_to_grammar(const json & schema) {
+    SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
+    auto copy = schema;
+    converter.resolve_refs(copy, "input");
+    converter.visit(copy, "");
+    converter.check_errors();
+    return converter.format_grammar();
+}
diff --git a/llama/json-schema-to-grammar.h b/llama/json-schema-to-grammar.h
new file mode 100644
index 000000000..39b451cab
--- /dev/null
+++ b/llama/json-schema-to-grammar.h
@@ -0,0 +1,34 @@
+/**
+ * llama.cpp - commit 46e3556e01b824e52395fb050b29804b6cff2a7c - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
+#include "json.hpp"
+
+std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
diff --git a/llama/json.hpp b/llama/json.hpp
new file mode 100644
index 000000000..a858728c4
--- /dev/null
+++ b/llama/json.hpp
@@ -0,0 +1,24766 @@
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+/****************************************************************************\
+ * Note on documentation: The source files contain links to the online      *
+ * documentation of the public API at https://json.nlohmann.me. This URL    *
+ * contains the most recent documentation and should also be applicable to  *
+ * previous versions; documentation for deprecated functions is not         *
+ * removed, but marked deprecated. See "Generate documentation" section in  *
+ * file docs/README.md.                                                     *
+\****************************************************************************/
+
+#ifndef INCLUDE_NLOHMANN_JSON_HPP_
+#define INCLUDE_NLOHMANN_JSON_HPP_
+
+#include <algorithm> // all_of, find, for_each
+#include <cstddef> // nullptr_t, ptrdiff_t, size_t
+#include <functional> // hash, less
+#include <initializer_list> // initializer_list
+#ifndef JSON_NO_IO
+    #include <iosfwd> // istream, ostream
+#endif  // JSON_NO_IO
+#include <iterator> // random_access_iterator_tag
+#include <memory> // unique_ptr
+#include <string> // string, stoi, to_string
+#include <utility> // declval, forward, move, pair, swap
+#include <vector> // vector
+
+// #include <nlohmann/adl_serializer.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <utility>
+
+// #include <nlohmann/detail/abi_macros.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+// This file contains all macro definitions affecting or depending on the ABI
+
+#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK
+    #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH)
+        #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3
+            #warning "Already included a different version of the library!"
+        #endif
+    #endif
+#endif
+
+#define NLOHMANN_JSON_VERSION_MAJOR 3   // NOLINT(modernize-macro-to-enum)
+#define NLOHMANN_JSON_VERSION_MINOR 11  // NOLINT(modernize-macro-to-enum)
+#define NLOHMANN_JSON_VERSION_PATCH 3   // NOLINT(modernize-macro-to-enum)
+
+#ifndef JSON_DIAGNOSTICS
+    #define JSON_DIAGNOSTICS 0
+#endif
+
+#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
+    #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0
+#endif
+
+#if JSON_DIAGNOSTICS
+    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag
+#else
+    #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS
+#endif
+
+#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON
+    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp
+#else
+    #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION
+    #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0
+#endif
+
+// Construct the namespace ABI tags component
+#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b
+#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \
+    NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b)
+
+#define NLOHMANN_JSON_ABI_TAGS                                       \
+    NLOHMANN_JSON_ABI_TAGS_CONCAT(                                   \
+            NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS,                       \
+            NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON)
+
+// Construct the namespace version component
+#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \
+    _v ## major ## _ ## minor ## _ ## patch
+#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \
+    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch)
+
+#if NLOHMANN_JSON_NAMESPACE_NO_VERSION
+#define NLOHMANN_JSON_NAMESPACE_VERSION
+#else
+#define NLOHMANN_JSON_NAMESPACE_VERSION                                 \
+    NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \
+                                           NLOHMANN_JSON_VERSION_MINOR, \
+                                           NLOHMANN_JSON_VERSION_PATCH)
+#endif
+
+// Combine namespace components
+#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b
+#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \
+    NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b)
+
+#ifndef NLOHMANN_JSON_NAMESPACE
+#define NLOHMANN_JSON_NAMESPACE               \
+    nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \
+            NLOHMANN_JSON_ABI_TAGS,           \
+            NLOHMANN_JSON_NAMESPACE_VERSION)
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN
+#define NLOHMANN_JSON_NAMESPACE_BEGIN                \
+    namespace nlohmann                               \
+    {                                                \
+    inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \
+                NLOHMANN_JSON_ABI_TAGS,              \
+                NLOHMANN_JSON_NAMESPACE_VERSION)     \
+    {
+#endif
+
+#ifndef NLOHMANN_JSON_NAMESPACE_END
+#define NLOHMANN_JSON_NAMESPACE_END                                     \
+#  define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast<T>(expr))
+    }  // namespace nlohmann
+#endif
+
+// #include <nlohmann/detail/conversions/from_json.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <algorithm> // transform
+#include <array> // array
+#include <forward_list> // forward_list
+#include <iterator> // inserter, front_inserter, end
+#include <map> // map
+#include <string> // string
+#include <tuple> // tuple, make_tuple
+#include <type_traits> // is_arithmetic, is_same, is_enum, underlying_type, is_convertible
+#include <unordered_map> // unordered_map
+#include <utility> // pair, declval
+#include <valarray> // valarray
+
+// #include <nlohmann/detail/exceptions.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <cstddef> // nullptr_t
+#include <exception> // exception
+#if JSON_DIAGNOSTICS
+    #include <numeric> // accumulate
+#endif
+#include <stdexcept> // runtime_error
+#include <string> // to_string
+#include <vector> // vector
+
+// #include <nlohmann/detail/value_t.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <array> // array
+#include <cstddef> // size_t
+#include <cstdint> // uint8_t
+#include <string> // string
+
+// #include <nlohmann/detail/macro_scope.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <utility> // declval, pair
+// #include <nlohmann/detail/meta/detected.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <type_traits>
+
+// #include <nlohmann/detail/meta/void_t.hpp>
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+// #include <nlohmann/detail/abi_macros.hpp>
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+template<typename ...Ts> struct make_void
+{
+    using type = void;
+};
+template<typename ...Ts> using void_t = typename make_void<Ts...>::type;
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+// https://en.cppreference.com/w/cpp/experimental/is_detected
+struct nonesuch
+{
+    nonesuch() = delete;
+    ~nonesuch() = delete;
+    nonesuch(nonesuch const&) = delete;
+    nonesuch(nonesuch const&&) = delete;
+    void operator=(nonesuch const&) = delete;
+    void operator=(nonesuch&&) = delete;
+};
+
+template<class Default,
+         class AlwaysVoid,
+         template<class...> class Op,
+         class... Args>
+struct detector
+{
+    using value_t = std::false_type;
+    using type = Default;
+};
+
+template<class Default, template<class...> class Op, class... Args>
+struct detector<Default, void_t<Op<Args...>>, Op, Args...>
+{
+    using value_t = std::true_type;
+    using type = Op<Args...>;
+};
+
+template<template<class...> class Op, class... Args>
+using is_detected = typename detector<nonesuch, void, Op, Args...>::value_t;
+
+template<template<class...> class Op, class... Args>
+struct is_detected_lazy : is_detected<Op, Args...> { };
+
+template<template<class...> class Op, class... Args>
+using detected_t = typename detector<nonesuch, void, Op, Args...>::type;
+
+template<class Default, template<class...> class Op, class... Args>
+using detected_or = detector<Default, void, Op, Args...>;
+
+template<class Default, template<class...> class Op, class... Args>
+using detected_or_t = typename detected_or<Default, Op, Args...>::type;
+
+template<class Expected, template<class...> class Op, class... Args>
+using is_detected_exact = std::is_same<Expected, detected_t<Op, Args...>>;
+
+template<class To, template<class...> class Op, class... Args>
+using is_detected_convertible =
+    std::is_convertible<detected_t<Op, Args...>, To>;
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include <nlohmann/thirdparty/hedley/hedley.hpp>
+
+
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-FileCopyrightText: 2016-2021 Evan Nemerson 
+// SPDX-License-Identifier: MIT
+
+/* Hedley - https://nemequ.github.io/hedley
+ * Created by Evan Nemerson 
+ */
+
+#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 15)
+#if defined(JSON_HEDLEY_VERSION)
+    #undef JSON_HEDLEY_VERSION
+#endif
+#define JSON_HEDLEY_VERSION 15
+
+#if defined(JSON_HEDLEY_STRINGIFY_EX)
+    #undef JSON_HEDLEY_STRINGIFY_EX
+#endif
+#define JSON_HEDLEY_STRINGIFY_EX(x) #x
+
+#if defined(JSON_HEDLEY_STRINGIFY)
+    #undef JSON_HEDLEY_STRINGIFY
+#endif
+#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x)
+
+#if defined(JSON_HEDLEY_CONCAT_EX)
+    #undef JSON_HEDLEY_CONCAT_EX
+#endif
+#define JSON_HEDLEY_CONCAT_EX(a,b) a##b
+
+#if defined(JSON_HEDLEY_CONCAT)
+    #undef JSON_HEDLEY_CONCAT
+#endif
+#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b)
+
+#if defined(JSON_HEDLEY_CONCAT3_EX)
+    #undef JSON_HEDLEY_CONCAT3_EX
+#endif
+#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c
+
+#if defined(JSON_HEDLEY_CONCAT3)
+    #undef JSON_HEDLEY_CONCAT3
+#endif
+#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c)
+
+#if defined(JSON_HEDLEY_VERSION_ENCODE)
+    #undef JSON_HEDLEY_VERSION_ENCODE
+#endif
+#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision))
+
+#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR)
+    #undef JSON_HEDLEY_VERSION_DECODE_MAJOR
+#endif
+#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000)
+
+#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR)
+    #undef JSON_HEDLEY_VERSION_DECODE_MINOR
+#endif
+#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000)
+
+#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION)
+    #undef JSON_HEDLEY_VERSION_DECODE_REVISION
+#endif
+#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000)
+
+#if defined(JSON_HEDLEY_GNUC_VERSION)
+    #undef JSON_HEDLEY_GNUC_VERSION
+#endif
+#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__)
+    #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__)
+#elif defined(__GNUC__)
+    #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK)
+    #undef JSON_HEDLEY_GNUC_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_GNUC_VERSION)
+    #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_MSVC_VERSION)
+    #undef JSON_HEDLEY_MSVC_VERSION
+#endif
+#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) && !defined(__ICL)
+    #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100)
+#elif defined(_MSC_FULL_VER) && !defined(__ICL)
+    #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10)
+#elif defined(_MSC_VER) && !defined(__ICL)
+    #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0)
+#endif
+
+#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK)
+    #undef JSON_HEDLEY_MSVC_VERSION_CHECK
+#endif
+#if !defined(JSON_HEDLEY_MSVC_VERSION)
+    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0)
+#elif defined(_MSC_VER) && (_MSC_VER >= 1400)
+    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch)))
+#elif defined(_MSC_VER) && (_MSC_VER >= 1200)
+    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch)))
+#else
+    #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor)))
+#endif
+
+#if defined(JSON_HEDLEY_INTEL_VERSION)
+    #undef JSON_HEDLEY_INTEL_VERSION
+#endif
+#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && !defined(__ICL)
+    #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE)
+#elif defined(__INTEL_COMPILER) && !defined(__ICL)
+    #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0)
+#endif
+
+#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK)
+    #undef JSON_HEDLEY_INTEL_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_INTEL_VERSION)
+    #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_INTEL_CL_VERSION)
+    #undef JSON_HEDLEY_INTEL_CL_VERSION
+#endif
+#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && defined(__ICL)
+    #define JSON_HEDLEY_INTEL_CL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER, __INTEL_COMPILER_UPDATE, 0)
+#endif
+
+#if defined(JSON_HEDLEY_INTEL_CL_VERSION_CHECK)
+    #undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_INTEL_CL_VERSION)
+    #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_CL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_PGI_VERSION)
+    #undef JSON_HEDLEY_PGI_VERSION
+#endif
+#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__)
+    #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__)
+#endif
+
+#if defined(JSON_HEDLEY_PGI_VERSION_CHECK)
+    #undef JSON_HEDLEY_PGI_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_PGI_VERSION)
+    #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_SUNPRO_VERSION)
+    #undef JSON_HEDLEY_SUNPRO_VERSION
+#endif
+#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000)
+    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10)
+#elif defined(__SUNPRO_C)
+    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf)
+#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000)
+    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10)
+#elif defined(__SUNPRO_CC)
+    #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf)
+#endif
+
+#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK)
+    #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_SUNPRO_VERSION)
+    #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION)
+    #undef JSON_HEDLEY_EMSCRIPTEN_VERSION
+#endif
+#if defined(__EMSCRIPTEN__)
+    #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__)
+#endif
+
+#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK)
+    #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION)
+    #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_ARM_VERSION)
+    #undef JSON_HEDLEY_ARM_VERSION
+#endif
+#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION)
+    #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100)
+#elif defined(__CC_ARM) && defined(__ARMCC_VERSION)
+    #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100)
+#endif
+
+#if defined(JSON_HEDLEY_ARM_VERSION_CHECK)
+    #undef JSON_HEDLEY_ARM_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_ARM_VERSION)
+    #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_IBM_VERSION)
+    #undef JSON_HEDLEY_IBM_VERSION
+#endif
+#if defined(__ibmxl__)
+    #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__)
+#elif defined(__xlC__) && defined(__xlC_ver__)
+    #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff)
+#elif defined(__xlC__)
+    #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0)
+#endif
+
+#if defined(JSON_HEDLEY_IBM_VERSION_CHECK)
+    #undef JSON_HEDLEY_IBM_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_IBM_VERSION)
+    #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_VERSION)
+    #undef JSON_HEDLEY_TI_VERSION
+#endif
+#if \
+    defined(__TI_COMPILER_VERSION__) && \
+    ( \
+      defined(__TMS470__) || defined(__TI_ARM__) || \
+      defined(__MSP430__) || \
+      defined(__TMS320C2000__) \
+    )
+#if (__TI_COMPILER_VERSION__ >= 16000000)
+    #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+#endif
+
+#if defined(JSON_HEDLEY_TI_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_VERSION)
+    #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL2000_VERSION)
+    #undef JSON_HEDLEY_TI_CL2000_VERSION
+#endif
+#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__)
+    #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_CL2000_VERSION)
+    #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL430_VERSION)
+    #undef JSON_HEDLEY_TI_CL430_VERSION
+#endif
+#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__)
+    #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_CL430_VERSION)
+    #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_ARMCL_VERSION)
+    #undef JSON_HEDLEY_TI_ARMCL_VERSION
+#endif
+#if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__))
+    #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+
+#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_ARMCL_VERSION)
+    #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL6X_VERSION)
+    #undef JSON_HEDLEY_TI_CL6X_VERSION
+#endif
+#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__)
+    #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_CL6X_VERSION)
+    #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL7X_VERSION)
+    #undef JSON_HEDLEY_TI_CL7X_VERSION
+#endif
+#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__)
+    #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+
+#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_CL7X_VERSION)
+    #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TI_CLPRU_VERSION)
+    #undef JSON_HEDLEY_TI_CLPRU_VERSION
+#endif
+#if defined(__TI_COMPILER_VERSION__) && defined(__PRU__)
+    #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000))
+#endif
+
+#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK)
+    #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TI_CLPRU_VERSION)
+    #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_CRAY_VERSION)
+    #undef JSON_HEDLEY_CRAY_VERSION
+#endif
+#if defined(_CRAYC)
+    #if defined(_RELEASE_PATCHLEVEL)
+        #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL)
+    #else
+        #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0)
+    #endif
+#endif
+
+#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK)
+    #undef JSON_HEDLEY_CRAY_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_CRAY_VERSION)
+    #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_IAR_VERSION)
+    #undef JSON_HEDLEY_IAR_VERSION
+#endif
+#if defined(__IAR_SYSTEMS_ICC__)
+    #if __VER__ > 1000
+        #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000))
+    #else
+        #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(__VER__ / 100, __VER__ % 100, 0)
+    #endif
+#endif
+
+#if defined(JSON_HEDLEY_IAR_VERSION_CHECK)
+    #undef JSON_HEDLEY_IAR_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_IAR_VERSION)
+    #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_TINYC_VERSION)
+    #undef JSON_HEDLEY_TINYC_VERSION
+#endif
+#if defined(__TINYC__)
+    #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100)
+#endif
+
+#if defined(JSON_HEDLEY_TINYC_VERSION_CHECK)
+    #undef JSON_HEDLEY_TINYC_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_TINYC_VERSION)
+    #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_DMC_VERSION)
+    #undef JSON_HEDLEY_DMC_VERSION
+#endif
+#if defined(__DMC__)
+    #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf)
+#endif
+
+#if defined(JSON_HEDLEY_DMC_VERSION_CHECK)
+    #undef JSON_HEDLEY_DMC_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_DMC_VERSION)
+    #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_COMPCERT_VERSION)
+    #undef JSON_HEDLEY_COMPCERT_VERSION
+#endif
+#if defined(__COMPCERT_VERSION__)
+    #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100)
+#endif
+
+#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK)
+    #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_COMPCERT_VERSION)
+    #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_PELLES_VERSION)
+    #undef JSON_HEDLEY_PELLES_VERSION
+#endif
+#if defined(__POCC__)
+    #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0)
+#endif
+
+#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK)
+    #undef JSON_HEDLEY_PELLES_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_PELLES_VERSION)
+    #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_MCST_LCC_VERSION)
+    #undef JSON_HEDLEY_MCST_LCC_VERSION
+#endif
+#if defined(__LCC__) && defined(__LCC_MINOR__)
+    #define JSON_HEDLEY_MCST_LCC_VERSION JSON_HEDLEY_VERSION_ENCODE(__LCC__ / 100, __LCC__ % 100, __LCC_MINOR__)
+#endif
+
+#if defined(JSON_HEDLEY_MCST_LCC_VERSION_CHECK)
+    #undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_MCST_LCC_VERSION)
+    #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_MCST_LCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_VERSION)
+    #undef JSON_HEDLEY_GCC_VERSION
+#endif
+#if \
+    defined(JSON_HEDLEY_GNUC_VERSION) && \
+    !defined(__clang__) && \
+    !defined(JSON_HEDLEY_INTEL_VERSION) && \
+    !defined(JSON_HEDLEY_PGI_VERSION) && \
+    !defined(JSON_HEDLEY_ARM_VERSION) && \
+    !defined(JSON_HEDLEY_CRAY_VERSION) && \
+    !defined(JSON_HEDLEY_TI_VERSION) && \
+    !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \
+    !defined(JSON_HEDLEY_TI_CL430_VERSION) && \
+    !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \
+    !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \
+    !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \
+    !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \
+    !defined(__COMPCERT__) && \
+    !defined(JSON_HEDLEY_MCST_LCC_VERSION)
+    #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION
+#endif
+
+#if defined(JSON_HEDLEY_GCC_VERSION_CHECK)
+    #undef JSON_HEDLEY_GCC_VERSION_CHECK
+#endif
+#if defined(JSON_HEDLEY_GCC_VERSION)
+    #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch))
+#else
+    #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_ATTRIBUTE)
+    #undef JSON_HEDLEY_HAS_ATTRIBUTE
+#endif
+#if \
+  defined(__has_attribute) && \
+  ( \
+    (!defined(JSON_HEDLEY_IAR_VERSION) || JSON_HEDLEY_IAR_VERSION_CHECK(8,5,9)) \
+  )
+#  define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute)
+#else
+#  define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE)
+    #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE
+#endif
+#if defined(__has_attribute)
+    #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE)
+    #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE
+#endif
+#if defined(__has_attribute)
+    #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)
+#else
+    #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE)
+    #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE
+#endif
+#if \
+    defined(__has_cpp_attribute) && \
+    defined(__cplusplus) && \
+    (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0))
+    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute)
+#else
+    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS)
+    #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS
+#endif
+#if !defined(__cplusplus) || !defined(__has_cpp_attribute)
+    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0)
+#elif \
+    !defined(JSON_HEDLEY_PGI_VERSION) && \
+    !defined(JSON_HEDLEY_IAR_VERSION) && \
+    (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \
+    (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0))
+    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute)
+#else
+    #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE)
+    #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE
+#endif
+#if defined(__has_cpp_attribute) && defined(__cplusplus)
+    #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE)
+    #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE
+#endif
+#if defined(__has_cpp_attribute) && defined(__cplusplus)
+    #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute)
+#else
+    #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_BUILTIN)
+    #undef JSON_HEDLEY_HAS_BUILTIN
+#endif
+#if defined(__has_builtin)
+    #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin)
+#else
+    #define JSON_HEDLEY_HAS_BUILTIN(builtin) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN)
+    #undef JSON_HEDLEY_GNUC_HAS_BUILTIN
+#endif
+#if defined(__has_builtin)
+    #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_BUILTIN)
+    #undef JSON_HEDLEY_GCC_HAS_BUILTIN
+#endif
+#if defined(__has_builtin)
+    #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin)
+#else
+    #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_FEATURE)
+    #undef JSON_HEDLEY_HAS_FEATURE
+#endif
+#if defined(__has_feature)
+    #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature)
+#else
+    #define JSON_HEDLEY_HAS_FEATURE(feature) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_FEATURE)
+    #undef JSON_HEDLEY_GNUC_HAS_FEATURE
+#endif
+#if defined(__has_feature)
+    #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_FEATURE)
+    #undef JSON_HEDLEY_GCC_HAS_FEATURE
+#endif
+#if defined(__has_feature)
+    #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature)
+#else
+    #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_EXTENSION)
+    #undef JSON_HEDLEY_HAS_EXTENSION
+#endif
+#if defined(__has_extension)
+    #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension)
+#else
+    #define JSON_HEDLEY_HAS_EXTENSION(extension) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION)
+    #undef JSON_HEDLEY_GNUC_HAS_EXTENSION
+#endif
+#if defined(__has_extension)
+    #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_EXTENSION)
+    #undef JSON_HEDLEY_GCC_HAS_EXTENSION
+#endif
+#if defined(__has_extension)
+    #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension)
+#else
+    #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE)
+    #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE
+#endif
+#if defined(__has_declspec_attribute)
+    #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute)
+#else
+    #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE)
+    #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE
+#endif
+#if defined(__has_declspec_attribute)
+    #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE)
+    #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE
+#endif
+#if defined(__has_declspec_attribute)
+    #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute)
+#else
+    #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_HAS_WARNING)
+    #undef JSON_HEDLEY_HAS_WARNING
+#endif
+#if defined(__has_warning)
+    #define JSON_HEDLEY_HAS_WARNING(warning) __has_warning(warning)
+#else
+    #define JSON_HEDLEY_HAS_WARNING(warning) (0)
+#endif
+
+#if defined(JSON_HEDLEY_GNUC_HAS_WARNING)
+    #undef JSON_HEDLEY_GNUC_HAS_WARNING
+#endif
+#if defined(__has_warning)
+    #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning)
+#else
+    #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_GCC_HAS_WARNING)
+    #undef JSON_HEDLEY_GCC_HAS_WARNING
+#endif
+#if defined(__has_warning)
+    #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning)
+#else
+    #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if \
+    (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \
+    defined(__clang__) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \
+    JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \
+    JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \
+    (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR))
+    #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value)
+#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0)
+    #define JSON_HEDLEY_PRAGMA(value) __pragma(value)
+#else
+    #define JSON_HEDLEY_PRAGMA(value)
+#endif
+
+#if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH)
+    #undef JSON_HEDLEY_DIAGNOSTIC_PUSH
+#endif
+#if defined(JSON_HEDLEY_DIAGNOSTIC_POP)
+    #undef JSON_HEDLEY_DIAGNOSTIC_POP
+#endif
+#if defined(__clang__)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push")
+    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop")
+#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)")
+    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)")
+#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push")
+    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop")
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push))
+    #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop))
+#elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push")
+    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop")
+#elif \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push")
+    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop")
+#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)")
+    #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)")
+#else
+    #define JSON_HEDLEY_DIAGNOSTIC_PUSH
+    #define JSON_HEDLEY_DIAGNOSTIC_POP
+#endif
+
+/* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for
+   HEDLEY INTERNAL USE ONLY.  API subject to change without notice. */
+#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_)
+    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_
+#endif
+#if defined(__cplusplus)
+#  if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat")
+#    if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions")
+#      if JSON_HEDLEY_HAS_WARNING("-Wc++1z-extensions")
+#        define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \
+    _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \
+    _Pragma("clang diagnostic ignored \"-Wc++1z-extensions\"") \
+    xpr \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#      else
+#        define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \
+    _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \
+    xpr \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#      endif
+#    else
+#      define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \
+    xpr \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#    endif
+#  endif
+#endif
+#if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x
+#endif
+
+#if defined(JSON_HEDLEY_CONST_CAST)
+    #undef JSON_HEDLEY_CONST_CAST
+#endif
+#if defined(__cplusplus)
+#  define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast<T>(expr))
+#elif \
+  JSON_HEDLEY_HAS_WARNING("-Wcast-qual") || \
+  JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \
+  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+#  define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \
+        JSON_HEDLEY_DIAGNOSTIC_PUSH \
+        JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \
+        ((T) (expr)); \
+        JSON_HEDLEY_DIAGNOSTIC_POP \
+    }))
+#else
+#  define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr))
+#endif
+
+#if defined(JSON_HEDLEY_REINTERPRET_CAST)
+    #undef JSON_HEDLEY_REINTERPRET_CAST
+#endif
+#if defined(__cplusplus)
+    #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast<T>(expr))
+#else
+    #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr))
+#endif
+
+#if defined(JSON_HEDLEY_STATIC_CAST)
+    #undef JSON_HEDLEY_STATIC_CAST
+#endif
+#if defined(__cplusplus)
+    #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast<T>(expr))
+#else
+    #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr))
+#endif
+
+#if defined(JSON_HEDLEY_CPP_CAST)
+    #undef JSON_HEDLEY_CPP_CAST
+#endif
+#if defined(__cplusplus)
+#  if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast")
+#    define JSON_HEDLEY_CPP_CAST(T, expr) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \
+    ((T) (expr)) \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#  elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0)
+#    define JSON_HEDLEY_CPP_CAST(T, expr) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("diag_suppress=Pe137") \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#  else
+#    define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr))
+#  endif
+#else
+#  define JSON_HEDLEY_CPP_CAST(T, expr) (expr)
+#endif
+
+#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED)
+    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations")
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"")
+#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)")
+#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:1478 1786))
+#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1216,1444,1445")
+#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444")
+#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"")
+#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996))
+#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444")
+#elif \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718")
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)")
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)")
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215")
+#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)")
+#else
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED
+#endif
+
+#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS)
+    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas")
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"")
+#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)")
+#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161))
+#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675")
+#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"")
+#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068))
+#elif \
+    JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163")
+#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163")
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161")
+#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 161")
+#else
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS
+#endif
+
+#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES)
+    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes")
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"")
+#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"")
+#elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)")
+#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292))
+#elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030))
+#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097,1098")
+#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097")
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)")
+#elif \
+    JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173")
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097")
+#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097")
+#else
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES
+#endif
+
+#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL)
+    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wcast-qual")
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"")
+#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)")
+#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"")
+#else
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL
+#endif
+
+#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION)
+    #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wunused-function")
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("clang diagnostic ignored \"-Wunused-function\"")
+#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("GCC diagnostic ignored \"-Wunused-function\"")
+#elif JSON_HEDLEY_MSVC_VERSION_CHECK(1,0,0)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505))
+#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("diag_suppress 3142")
+#else
+    #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION
+#endif
+
+#if defined(JSON_HEDLEY_DEPRECATED)
+    #undef JSON_HEDLEY_DEPRECATED
+#endif
+#if defined(JSON_HEDLEY_DEPRECATED_FOR)
+    #undef JSON_HEDLEY_DEPRECATED_FOR
+#endif
+#if \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated("Since " # since))
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated("Since " #since "; use " #replacement))
+#elif \
+    (JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) && !defined(JSON_HEDLEY_IAR_VERSION)) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \
+    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__("Since " #since)))
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__("Since " #since "; use " #replacement)))
+#elif defined(__cplusplus) && (__cplusplus >= 201402L)
+    #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since)]])
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since "; use " #replacement)]])
+#elif \
+    JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \
+    JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)
+    #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__))
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__))
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \
+    JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated)
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated)
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_DEPRECATED(since) _Pragma("deprecated")
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) _Pragma("deprecated")
+#else
+    #define JSON_HEDLEY_DEPRECATED(since)
+    #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement)
+#endif
+
+#if defined(JSON_HEDLEY_UNAVAILABLE)
+    #undef JSON_HEDLEY_UNAVAILABLE
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__("Not available until " #available_since)))
+#else
+    #define JSON_HEDLEY_UNAVAILABLE(available_since)
+#endif
+
+#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT)
+    #undef JSON_HEDLEY_WARN_UNUSED_RESULT
+#endif
+#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG)
+    #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \
+    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__))
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__))
+#elif (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L)
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]])
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]])
+#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard)
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]])
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]])
+#elif defined(_Check_return_) /* SAL */
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_
+#else
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT
+    #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg)
+#endif
+
+#if defined(JSON_HEDLEY_SENTINEL)
+    #undef JSON_HEDLEY_SENTINEL
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position)))
+#else
+    #define JSON_HEDLEY_SENTINEL(position)
+#endif
+
+#if defined(JSON_HEDLEY_NO_RETURN)
+    #undef JSON_HEDLEY_NO_RETURN
+#endif
+#if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_NO_RETURN __noreturn
+#elif \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__))
+#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L
+    #define JSON_HEDLEY_NO_RETURN _Noreturn
+#elif defined(__cplusplus) && (__cplusplus >= 201103L)
+    #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]])
+#elif \
+    JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)
+    #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__))
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)
+    #define JSON_HEDLEY_NO_RETURN _Pragma("does_not_return")
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_NO_RETURN __declspec(noreturn)
+#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus)
+    #define JSON_HEDLEY_NO_RETURN _Pragma("FUNC_NEVER_RETURNS;")
+#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0)
+    #define JSON_HEDLEY_NO_RETURN __attribute((noreturn))
+#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0)
+    #define JSON_HEDLEY_NO_RETURN __declspec(noreturn)
+#else
+    #define JSON_HEDLEY_NO_RETURN
+#endif
+
+#if defined(JSON_HEDLEY_NO_ESCAPE)
+    #undef JSON_HEDLEY_NO_ESCAPE
+#endif
+#if JSON_HEDLEY_HAS_ATTRIBUTE(noescape)
+    #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__))
+#else
+    #define JSON_HEDLEY_NO_ESCAPE
+#endif
+
+#if defined(JSON_HEDLEY_UNREACHABLE)
+    #undef JSON_HEDLEY_UNREACHABLE
+#endif
+#if defined(JSON_HEDLEY_UNREACHABLE_RETURN)
+    #undef JSON_HEDLEY_UNREACHABLE_RETURN
+#endif
+#if defined(JSON_HEDLEY_ASSUME)
+    #undef JSON_HEDLEY_ASSUME
+#endif
+#if \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_ASSUME(expr) __assume(expr)
+#elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume)
+    #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr)
+#elif \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0)
+    #if defined(__cplusplus)
+        #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr)
+    #else
+        #define JSON_HEDLEY_ASSUME(expr) _nassert(expr)
+    #endif
+#endif
+#if \
+    (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \
+    JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) || \
+    JSON_HEDLEY_CRAY_VERSION_CHECK(10,0,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable()
+#elif defined(JSON_HEDLEY_ASSUME)
+    #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0)
+#endif
+#if !defined(JSON_HEDLEY_ASSUME)
+    #if defined(JSON_HEDLEY_UNREACHABLE)
+        #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 1 : (JSON_HEDLEY_UNREACHABLE(), 1)))
+    #else
+        #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr)
+    #endif
+#endif
+#if defined(JSON_HEDLEY_UNREACHABLE)
+    #if  \
+        JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \
+        JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0)
+        #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value))
+    #else
+        #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE()
+    #endif
+#else
+    #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value)
+#endif
+#if !defined(JSON_HEDLEY_UNREACHABLE)
+    #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0)
+#endif
+
+JSON_HEDLEY_DIAGNOSTIC_PUSH
+#if JSON_HEDLEY_HAS_WARNING("-Wpedantic")
+    #pragma clang diagnostic ignored "-Wpedantic"
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus)
+    #pragma clang diagnostic ignored "-Wc++98-compat-pedantic"
+#endif
+#if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0)
+    #if defined(__clang__)
+        #pragma clang diagnostic ignored "-Wvariadic-macros"
+    #elif defined(JSON_HEDLEY_GCC_VERSION)
+        #pragma GCC diagnostic ignored "-Wvariadic-macros"
+    #endif
+#endif
+#if defined(JSON_HEDLEY_NON_NULL)
+    #undef JSON_HEDLEY_NON_NULL
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0)
+    #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__)))
+#else
+    #define JSON_HEDLEY_NON_NULL(...)
+#endif
+JSON_HEDLEY_DIAGNOSTIC_POP
+
+#if defined(JSON_HEDLEY_PRINTF_FORMAT)
+    #undef JSON_HEDLEY_PRINTF_FORMAT
+#endif
+#if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO)
+    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check)))
+#elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO)
+    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check)))
+#elif \
+    JSON_HEDLEY_HAS_ATTRIBUTE(format) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check)))
+#elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0)
+    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __declspec(vaformat(printf,string_idx,first_to_check))
+#else
+    #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check)
+#endif
+
+#if defined(JSON_HEDLEY_CONSTEXPR)
+    #undef JSON_HEDLEY_CONSTEXPR
+#endif
+#if defined(__cplusplus)
+    #if __cplusplus >= 201103L
+        #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr)
+    #endif
+#endif
+#if !defined(JSON_HEDLEY_CONSTEXPR)
+    #define JSON_HEDLEY_CONSTEXPR
+#endif
+
+#if defined(JSON_HEDLEY_PREDICT)
+    #undef JSON_HEDLEY_PREDICT
+#endif
+#if defined(JSON_HEDLEY_LIKELY)
+    #undef JSON_HEDLEY_LIKELY
+#endif
+#if defined(JSON_HEDLEY_UNLIKELY)
+    #undef JSON_HEDLEY_UNLIKELY
+#endif
+#if defined(JSON_HEDLEY_UNPREDICTABLE)
+    #undef JSON_HEDLEY_UNPREDICTABLE
+#endif
+#if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable)
+    #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr))
+#endif
+#if \
+  (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) && !defined(JSON_HEDLEY_PGI_VERSION)) || \
+  JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) || \
+  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+#  define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability(  (expr), (value), (probability))
+#  define JSON_HEDLEY_PREDICT_TRUE(expr, probability)   __builtin_expect_with_probability(!!(expr),    1   , (probability))
+#  define JSON_HEDLEY_PREDICT_FALSE(expr, probability)  __builtin_expect_with_probability(!!(expr),    0   , (probability))
+#  define JSON_HEDLEY_LIKELY(expr)                      __builtin_expect                 (!!(expr),    1                  )
+#  define JSON_HEDLEY_UNLIKELY(expr)                    __builtin_expect                 (!!(expr),    0                  )
+#elif \
+  (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \
+  JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \
+  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+  (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \
+  JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+  JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+  JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+  JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \
+  JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \
+  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \
+  JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \
+  JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+  JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+  JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \
+  JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \
+  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+#  define JSON_HEDLEY_PREDICT(expr, expected, probability) \
+    (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)))
+#  define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \
+    (__extension__ ({ \
+        double hedley_probability_ = (probability); \
+        ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \
+    }))
+#  define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \
+    (__extension__ ({ \
+        double hedley_probability_ = (probability); \
+        ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 1) : !!(expr))); \
+    }))
+#  define JSON_HEDLEY_LIKELY(expr)   __builtin_expect(!!(expr), 1)
+#  define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0)
+#else
+#  define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))
+#  define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr))
+#  define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr))
+#  define JSON_HEDLEY_LIKELY(expr) (!!(expr))
+#  define JSON_HEDLEY_UNLIKELY(expr) (!!(expr))
+#endif
+#if !defined(JSON_HEDLEY_UNPREDICTABLE)
+    #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5)
+#endif
+
+#if defined(JSON_HEDLEY_MALLOC)
+    #undef JSON_HEDLEY_MALLOC
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_MALLOC __attribute__((__malloc__))
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)
+    #define JSON_HEDLEY_MALLOC _Pragma("returns_new_memory")
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_MALLOC __declspec(restrict)
+#else
+    #define JSON_HEDLEY_MALLOC
+#endif
+
+#if defined(JSON_HEDLEY_PURE)
+    #undef JSON_HEDLEY_PURE
+#endif
+#if \
+  JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \
+  JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \
+  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+  JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+  JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+  JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+  JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+  (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+  (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+  (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+  (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+  JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+  JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+  JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \
+  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+#  define JSON_HEDLEY_PURE __attribute__((__pure__))
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)
+#  define JSON_HEDLEY_PURE _Pragma("does_not_write_global_data")
+#elif defined(__cplusplus) && \
+    ( \
+      JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \
+      JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) || \
+      JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \
+    )
+#  define JSON_HEDLEY_PURE _Pragma("FUNC_IS_PURE;")
+#else
+#  define JSON_HEDLEY_PURE
+#endif
+
+#if defined(JSON_HEDLEY_CONST)
+    #undef JSON_HEDLEY_CONST
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(const) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_CONST __attribute__((__const__))
+#elif \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0)
+    #define JSON_HEDLEY_CONST _Pragma("no_side_effect")
+#else
+    #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE
+#endif
+
+#if defined(JSON_HEDLEY_RESTRICT)
+    #undef JSON_HEDLEY_RESTRICT
+#endif
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus)
+    #define JSON_HEDLEY_RESTRICT restrict
+#elif \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+    JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \
+    JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \
+    defined(__clang__) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_RESTRICT __restrict
+#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus)
+    #define JSON_HEDLEY_RESTRICT _Restrict
+#else
+    #define JSON_HEDLEY_RESTRICT
+#endif
+
+#if defined(JSON_HEDLEY_INLINE)
+    #undef JSON_HEDLEY_INLINE
+#endif
+#if \
+    (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \
+    (defined(__cplusplus) && (__cplusplus >= 199711L))
+    #define JSON_HEDLEY_INLINE inline
+#elif \
+    defined(JSON_HEDLEY_GCC_VERSION) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0)
+    #define JSON_HEDLEY_INLINE __inline__
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_INLINE __inline
+#else
+    #define JSON_HEDLEY_INLINE
+#endif
+
+#if defined(JSON_HEDLEY_ALWAYS_INLINE)
+    #undef JSON_HEDLEY_ALWAYS_INLINE
+#endif
+#if \
+  JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \
+  JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \
+  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+  JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+  JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+  JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+  JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+  (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+  (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+  (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+  (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+  JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+  JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+  JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+  JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \
+  JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)
+#  define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE
+#elif \
+  JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \
+  JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+#  define JSON_HEDLEY_ALWAYS_INLINE __forceinline
+#elif defined(__cplusplus) && \
+    ( \
+      JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+      JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+      JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+      JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \
+      JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+      JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \
+    )
+#  define JSON_HEDLEY_ALWAYS_INLINE _Pragma("FUNC_ALWAYS_INLINE;")
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+#  define JSON_HEDLEY_ALWAYS_INLINE _Pragma("inline=forced")
+#else
+#  define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE
+#endif
+
+#if defined(JSON_HEDLEY_NEVER_INLINE)
+    #undef JSON_HEDLEY_NEVER_INLINE
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \
+    JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \
+    (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \
+    (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \
+    (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \
+    (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \
+    JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \
+    JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \
+    JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0)
+    #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__))
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline)
+#elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0)
+    #define JSON_HEDLEY_NEVER_INLINE _Pragma("noinline")
+#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus)
+    #define JSON_HEDLEY_NEVER_INLINE _Pragma("FUNC_CANNOT_INLINE;")
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+    #define JSON_HEDLEY_NEVER_INLINE _Pragma("inline=never")
+#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0)
+    #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline))
+#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0)
+    #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline)
+#else
+    #define JSON_HEDLEY_NEVER_INLINE
+#endif
+
+#if defined(JSON_HEDLEY_PRIVATE)
+    #undef JSON_HEDLEY_PRIVATE
+#endif
+#if defined(JSON_HEDLEY_PUBLIC)
+    #undef JSON_HEDLEY_PUBLIC
+#endif
+#if defined(JSON_HEDLEY_IMPORT)
+    #undef JSON_HEDLEY_IMPORT
+#endif
+#if defined(_WIN32) || defined(__CYGWIN__)
+#  define JSON_HEDLEY_PRIVATE
+#  define JSON_HEDLEY_PUBLIC   __declspec(dllexport)
+#  define JSON_HEDLEY_IMPORT   __declspec(dllimport)
+#else
+#  if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \
+    JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \
+    ( \
+      defined(__TI_EABI__) && \
+      ( \
+        (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \
+        JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \
+      ) \
+    ) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+#    define JSON_HEDLEY_PRIVATE __attribute__((__visibility__("hidden")))
+#    define JSON_HEDLEY_PUBLIC  __attribute__((__visibility__("default")))
+#  else
+#    define JSON_HEDLEY_PRIVATE
+#    define JSON_HEDLEY_PUBLIC
+#  endif
+#  define JSON_HEDLEY_IMPORT    extern
+#endif
+
+#if defined(JSON_HEDLEY_NO_THROW)
+    #undef JSON_HEDLEY_NO_THROW
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__))
+#elif \
+    JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0)
+    #define JSON_HEDLEY_NO_THROW __declspec(nothrow)
+#else
+    #define JSON_HEDLEY_NO_THROW
+#endif
+
+#if defined(JSON_HEDLEY_FALL_THROUGH)
+    #undef JSON_HEDLEY_FALL_THROUGH
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_FALL_THROUGH __attribute__((__fallthrough__))
+#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough)
+    #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]])
+#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough)
+    #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]])
+#elif defined(__fallthrough) /* SAL */
+    #define JSON_HEDLEY_FALL_THROUGH __fallthrough
+#else
+    #define JSON_HEDLEY_FALL_THROUGH
+#endif
+
+#if defined(JSON_HEDLEY_RETURNS_NON_NULL)
+    #undef JSON_HEDLEY_RETURNS_NON_NULL
+#endif
+#if \
+    JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__))
+#elif defined(_Ret_notnull_) /* SAL */
+    #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_
+#else
+    #define JSON_HEDLEY_RETURNS_NON_NULL
+#endif
+
+#if defined(JSON_HEDLEY_ARRAY_PARAM)
+    #undef JSON_HEDLEY_ARRAY_PARAM
+#endif
+#if \
+    defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && \
+    !defined(__STDC_NO_VLA__) && \
+    !defined(__cplusplus) && \
+    !defined(JSON_HEDLEY_PGI_VERSION) && \
+    !defined(JSON_HEDLEY_TINYC_VERSION)
+    #define JSON_HEDLEY_ARRAY_PARAM(name) (name)
+#else
+    #define JSON_HEDLEY_ARRAY_PARAM(name)
+#endif
+
+#if defined(JSON_HEDLEY_IS_CONSTANT)
+    #undef JSON_HEDLEY_IS_CONSTANT
+#endif
+#if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR)
+    #undef JSON_HEDLEY_REQUIRE_CONSTEXPR
+#endif
+/* JSON_HEDLEY_IS_CONSTEXPR_ is for
+   HEDLEY INTERNAL USE ONLY.  API subject to change without notice. */
+#if defined(JSON_HEDLEY_IS_CONSTEXPR_)
+    #undef JSON_HEDLEY_IS_CONSTEXPR_
+#endif
+#if \
+    JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \
+    JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \
+    JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+    JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \
+    JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \
+    JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \
+    JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \
+    (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \
+    JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \
+    JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10)
+    #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr)
+#endif
+#if !defined(__cplusplus)
+#  if \
+       JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \
+       JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \
+       JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+       JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \
+       JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \
+       JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \
+       JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24)
+#if defined(__INTPTR_TYPE__)
+    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*)
+#else
+    #include <stdint.h>
+    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*)
+#endif
+#  elif \
+       ( \
+          defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \
+          !defined(JSON_HEDLEY_SUNPRO_VERSION) && \
+          !defined(JSON_HEDLEY_PGI_VERSION) && \
+          !defined(JSON_HEDLEY_IAR_VERSION)) || \
+       (JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) && !defined(JSON_HEDLEY_IAR_VERSION)) || \
+       JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \
+       JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \
+       JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \
+       JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0)
+#if defined(__INTPTR_TYPE__)
+    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0)
+#else
+    #include <stdint.h>
+    #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) * 0) : (int*) 0), int*: 1, void*: 0)
+#endif
+#  elif \
+       defined(JSON_HEDLEY_GCC_VERSION) || \
+       defined(JSON_HEDLEY_INTEL_VERSION) || \
+       defined(JSON_HEDLEY_TINYC_VERSION) || \
+       defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \
+       JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \
+       defined(JSON_HEDLEY_TI_CL2000_VERSION) || \
+       defined(JSON_HEDLEY_TI_CL6X_VERSION) || \
+       defined(JSON_HEDLEY_TI_CL7X_VERSION) || \
+       defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \
+       defined(__clang__)
+#    define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \
+        sizeof(void) != \
+        sizeof(*( \
+                  1 ? \
+                  ((void*) ((expr) * 0L) ) : \
+((struct { char v[sizeof(void) * 2]; } *) 1) \
+                ) \
+              ) \
+                                            )
+#  endif
+#endif
+#if defined(JSON_HEDLEY_IS_CONSTEXPR_)
+    #if !defined(JSON_HEDLEY_IS_CONSTANT)
+        #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr)
+    #endif
+    #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? (expr) : (-1))
+#else
+    #if !defined(JSON_HEDLEY_IS_CONSTANT)
+        #define JSON_HEDLEY_IS_CONSTANT(expr) (0)
+    #endif
+    #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr)
+#endif
+
+#if defined(JSON_HEDLEY_BEGIN_C_DECLS)
+    #undef JSON_HEDLEY_BEGIN_C_DECLS
+#endif
+#if defined(JSON_HEDLEY_END_C_DECLS)
+    #undef JSON_HEDLEY_END_C_DECLS
+#endif
+#if defined(JSON_HEDLEY_C_DECL)
+    #undef JSON_HEDLEY_C_DECL
+#endif
+#if defined(__cplusplus)
+    #define JSON_HEDLEY_BEGIN_C_DECLS extern "C" {
+    #define JSON_HEDLEY_END_C_DECLS }
+    #define JSON_HEDLEY_C_DECL extern "C"
+#else
+    #define JSON_HEDLEY_BEGIN_C_DECLS
+    #define JSON_HEDLEY_END_C_DECLS
+    #define JSON_HEDLEY_C_DECL
+#endif
+
+#if defined(JSON_HEDLEY_STATIC_ASSERT)
+    #undef JSON_HEDLEY_STATIC_ASSERT
+#endif
+#if \
+  !defined(__cplusplus) && ( \
+      (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \
+      (JSON_HEDLEY_HAS_FEATURE(c_static_assert) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \
+      JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \
+      JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \
+      defined(_Static_assert) \
+    )
+#  define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message)
+#elif \
+  (defined(__cplusplus) && (__cplusplus >= 201103L)) || \
+  JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) || \
+  JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+#  define JSON_HEDLEY_STATIC_ASSERT(expr, message) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message))
+#else
+#  define JSON_HEDLEY_STATIC_ASSERT(expr, message)
+#endif
+
+#if defined(JSON_HEDLEY_NULL)
+    #undef JSON_HEDLEY_NULL
+#endif
+#if defined(__cplusplus)
+    #if __cplusplus >= 201103L
+        #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr)
+    #elif defined(NULL)
+        #define JSON_HEDLEY_NULL NULL
+    #else
+        #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0)
+    #endif
+#elif defined(NULL)
+    #define JSON_HEDLEY_NULL NULL
+#else
+    #define JSON_HEDLEY_NULL ((void*) 0)
+#endif
+
+#if defined(JSON_HEDLEY_MESSAGE)
+    #undef JSON_HEDLEY_MESSAGE
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas")
+#  define JSON_HEDLEY_MESSAGE(msg) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \
+    JSON_HEDLEY_PRAGMA(message msg) \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#elif \
+  JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \
+  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg)
+#elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0)
+#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg)
+#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0)
+#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg))
+#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0)
+#  define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg))
+#else
+#  define JSON_HEDLEY_MESSAGE(msg)
+#endif
+
+#if defined(JSON_HEDLEY_WARNING)
+    #undef JSON_HEDLEY_WARNING
+#endif
+#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas")
+#  define JSON_HEDLEY_WARNING(msg) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \
+    JSON_HEDLEY_PRAGMA(clang warning msg) \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#elif \
+  JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \
+  JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \
+  JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0)
+#  define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg)
+#elif \
+  JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \
+  JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+#  define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg))
+#else
+#  define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg)
+#endif
+
+#if defined(JSON_HEDLEY_REQUIRE)
+    #undef JSON_HEDLEY_REQUIRE
+#endif
+#if defined(JSON_HEDLEY_REQUIRE_MSG)
+    #undef JSON_HEDLEY_REQUIRE_MSG
+#endif
+#if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if)
+#  if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat")
+#    define JSON_HEDLEY_REQUIRE(expr) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \
+    __attribute__((diagnose_if(!(expr), #expr, "error"))) \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#    define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \
+    JSON_HEDLEY_DIAGNOSTIC_PUSH \
+    _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \
+    __attribute__((diagnose_if(!(expr), msg, "error"))) \
+    JSON_HEDLEY_DIAGNOSTIC_POP
+#  else
+#    define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error")))
+#    define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error")))
+#  endif
+#else
+#  define JSON_HEDLEY_REQUIRE(expr)
+#  define JSON_HEDLEY_REQUIRE_MSG(expr,msg)
+#endif
+
+#if defined(JSON_HEDLEY_FLAGS)
+    #undef JSON_HEDLEY_FLAGS
+#endif
+#if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) && (!defined(__cplusplus) || JSON_HEDLEY_HAS_WARNING("-Wbitfield-enum-conversion"))
+    #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__))
+#else
+    #define JSON_HEDLEY_FLAGS
+#endif
+
+#if defined(JSON_HEDLEY_FLAGS_CAST)
+    #undef JSON_HEDLEY_FLAGS_CAST
+#endif
+#if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0)
+#  define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \
+        JSON_HEDLEY_DIAGNOSTIC_PUSH \
+        _Pragma("warning(disable:188)") \
+        ((T) (expr)); \
+        JSON_HEDLEY_DIAGNOSTIC_POP \
+    }))
+#else
+#  define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr)
+#endif
+
+#if defined(JSON_HEDLEY_EMPTY_BASES)
+    #undef JSON_HEDLEY_EMPTY_BASES
+#endif
+#if \
+    (JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0)) || \
+    JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0)
+    #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases)
+#else
+    #define JSON_HEDLEY_EMPTY_BASES
+#endif
+
+/* Remaining macros are deprecated. */
+
+#if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK)
+    #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK
+#endif
+#if defined(__clang__)
+    #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0)
+#else
+    #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch)
+#endif
+
+#if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE)
+    #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE)
+    #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN)
+    #undef JSON_HEDLEY_CLANG_HAS_BUILTIN
+#endif
+#define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_FEATURE)
+    #undef JSON_HEDLEY_CLANG_HAS_FEATURE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION)
+    #undef JSON_HEDLEY_CLANG_HAS_EXTENSION
+#endif
+#define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE)
+    #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE
+#endif
+#define JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute)
+
+#if defined(JSON_HEDLEY_CLANG_HAS_WARNING)
+    #undef JSON_HEDLEY_CLANG_HAS_WARNING
+#endif
+#define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning)
+
+#endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */
+
+
+// This file contains all internal macro definitions (except those affecting ABI)
+// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them
+
+// #include 
+
+
+// exclude unsupported compilers
+#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK)
+    #if defined(__clang__)
+        #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400
+            #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers"
+        #endif
+    #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER))
+        #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800
+            #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers"
+        #endif
+    #endif
+#endif
+
+// C++ language standard detection
+// if the user manually specified the used c++ version this is skipped
+#if !defined(JSON_HAS_CPP_20) && !defined(JSON_HAS_CPP_17) && !defined(JSON_HAS_CPP_14) && !defined(JSON_HAS_CPP_11)
+    #if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L)
+        #define JSON_HAS_CPP_20
+        #define JSON_HAS_CPP_17
+        #define JSON_HAS_CPP_14
+    #elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464
+        #define JSON_HAS_CPP_17
+        #define JSON_HAS_CPP_14
+    #elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1)
+        #define JSON_HAS_CPP_14
+    #endif
+    // the cpp 11 flag is always specified because it is the minimal required version
+    #define JSON_HAS_CPP_11
+#endif
+
+#ifdef __has_include
+    #if __has_include(<version>)
+        #include <version>
+    #endif
+#endif
+
+#if !defined(JSON_HAS_FILESYSTEM) && !defined(JSON_HAS_EXPERIMENTAL_FILESYSTEM)
+    #ifdef JSON_HAS_CPP_17
+        #if defined(__cpp_lib_filesystem)
+            #define JSON_HAS_FILESYSTEM 1
+        #elif defined(__cpp_lib_experimental_filesystem)
+            #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1
+        #elif !defined(__has_include)
+            #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1
+        #elif __has_include(<filesystem>)
+            #define JSON_HAS_FILESYSTEM 1
+        #elif __has_include(<experimental/filesystem>)
+            #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1
+        #endif
+
+        // std::filesystem does not work on MinGW GCC 8: https://sourceforge.net/p/mingw-w64/bugs/737/
+        #if defined(__MINGW32__) && defined(__GNUC__) && __GNUC__ == 8
+            #undef JSON_HAS_FILESYSTEM
+            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+        #endif
+
+        // no filesystem support before GCC 8: https://en.cppreference.com/w/cpp/compiler_support
+        #if defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 8
+            #undef JSON_HAS_FILESYSTEM
+            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+        #endif
+
+        // no filesystem support before Clang 7: https://en.cppreference.com/w/cpp/compiler_support
+        #if defined(__clang_major__) && __clang_major__ < 7
+            #undef JSON_HAS_FILESYSTEM
+            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+        #endif
+
+        // no filesystem support before MSVC 19.14: https://en.cppreference.com/w/cpp/compiler_support
+        #if defined(_MSC_VER) && _MSC_VER < 1914
+            #undef JSON_HAS_FILESYSTEM
+            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+        #endif
+
+        // no filesystem support before iOS 13
+        #if defined(__IPHONE_OS_VERSION_MIN_REQUIRED) && __IPHONE_OS_VERSION_MIN_REQUIRED < 130000
+            #undef JSON_HAS_FILESYSTEM
+            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+        #endif
+
+        // no filesystem support before macOS Catalina
+        #if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500
+            #undef JSON_HAS_FILESYSTEM
+            #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+        #endif
+    #endif
+#endif
+
+#ifndef JSON_HAS_EXPERIMENTAL_FILESYSTEM
+    #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 0
+#endif
+
+#ifndef JSON_HAS_FILESYSTEM
+    #define JSON_HAS_FILESYSTEM 0
+#endif
+
+#ifndef JSON_HAS_THREE_WAY_COMPARISON
+    #if defined(__cpp_impl_three_way_comparison) && __cpp_impl_three_way_comparison >= 201907L \
+        && defined(__cpp_lib_three_way_comparison) && __cpp_lib_three_way_comparison >= 201907L
+        #define JSON_HAS_THREE_WAY_COMPARISON 1
+    #else
+        #define JSON_HAS_THREE_WAY_COMPARISON 0
+    #endif
+#endif
+
+#ifndef JSON_HAS_RANGES
+    // ranges header shipping in GCC 11.1.0 (released 2021-04-27) has syntax error
+    #if defined(__GLIBCXX__) && __GLIBCXX__ == 20210427
+        #define JSON_HAS_RANGES 0
+    #elif defined(__cpp_lib_ranges)
+        #define JSON_HAS_RANGES 1
+    #else
+        #define JSON_HAS_RANGES 0
+    #endif
+#endif
+
+#ifndef JSON_HAS_STATIC_RTTI
+    #if !defined(_HAS_STATIC_RTTI) || _HAS_STATIC_RTTI != 0
+        #define JSON_HAS_STATIC_RTTI 1
+    #else
+        #define JSON_HAS_STATIC_RTTI 0
+    #endif
+#endif
+
+#ifdef JSON_HAS_CPP_17
+    #define JSON_INLINE_VARIABLE inline
+#else
+    #define JSON_INLINE_VARIABLE
+#endif
+
+#if JSON_HEDLEY_HAS_ATTRIBUTE(no_unique_address)
+    #define JSON_NO_UNIQUE_ADDRESS [[no_unique_address]]
+#else
+    #define JSON_NO_UNIQUE_ADDRESS
+#endif
+
+// disable documentation warnings on clang
+#if defined(__clang__)
+    #pragma clang diagnostic push
+    #pragma clang diagnostic ignored "-Wdocumentation"
+    #pragma clang diagnostic ignored "-Wdocumentation-unknown-command"
+#endif
+
+// allow disabling exceptions
+#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION)
+    #define JSON_THROW(exception) throw exception
+    #define JSON_TRY try
+    #define JSON_CATCH(exception) catch(exception)
+    #define JSON_INTERNAL_CATCH(exception) catch(exception)
+#else
+    #include <cstdlib>
+    #define JSON_THROW(exception) std::abort()
+    #define JSON_TRY if(true)
+    #define JSON_CATCH(exception) if(false)
+    #define JSON_INTERNAL_CATCH(exception) if(false)
+#endif
+
+// override exception macros
+#if defined(JSON_THROW_USER)
+    #undef JSON_THROW
+    #define JSON_THROW JSON_THROW_USER
+#endif
+#if defined(JSON_TRY_USER)
+    #undef JSON_TRY
+    #define JSON_TRY JSON_TRY_USER
+#endif
+#if defined(JSON_CATCH_USER)
+    #undef JSON_CATCH
+    #define JSON_CATCH JSON_CATCH_USER
+    #undef JSON_INTERNAL_CATCH
+    #define JSON_INTERNAL_CATCH JSON_CATCH_USER
+#endif
+#if defined(JSON_INTERNAL_CATCH_USER)
+    #undef JSON_INTERNAL_CATCH
+    #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER
+#endif
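+
+// Illustrative note (comment only, not part of the upstream header): a consumer can override
+// these hooks by defining the *_USER macros before including this header. A hypothetical example:
+//   #define JSON_THROW_USER(exception) std::abort()
+// replaces throwing with termination while the JSON_TRY/JSON_CATCH behavior stays unchanged
+// unless JSON_TRY_USER/JSON_CATCH_USER are also defined.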
+
+// allow overriding assert
+#if !defined(JSON_ASSERT)
+    #include <cassert> // assert
+    #define JSON_ASSERT(x) assert(x)
+#endif
+
+// allow to access some private functions (needed by the test suite)
+#if defined(JSON_TESTS_PRIVATE)
+    #define JSON_PRIVATE_UNLESS_TESTED public
+#else
+    #define JSON_PRIVATE_UNLESS_TESTED private
+#endif
+
+/*!
+@brief macro to briefly define a mapping between an enum and JSON
+@def NLOHMANN_JSON_SERIALIZE_ENUM
+@since version 3.4.0
+*/
+#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...)                                            \
+    template<typename BasicJsonType>                                                            \
+    inline void to_json(BasicJsonType& j, const ENUM_TYPE& e)                                   \
+    {                                                                                           \
+        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE " must be an enum!");          \
+        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \
+        auto it = std::find_if(std::begin(m), std::end(m),                                      \
+                               [e](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool  \
+        {                                                                                       \
+            return ej_pair.first == e;                                                          \
+        });                                                                                     \
+        j = ((it != std::end(m)) ? it : std::begin(m))->second;                                 \
+    }                                                                                           \
+    template<typename BasicJsonType>                                                            \
+    inline void from_json(const BasicJsonType& j, ENUM_TYPE& e)                                 \
+    {                                                                                           \
+        static_assert(std::is_enum<ENUM_TYPE>::value, #ENUM_TYPE " must be an enum!");          \
+        static const std::pair<ENUM_TYPE, BasicJsonType> m[] = __VA_ARGS__;                     \
+        auto it = std::find_if(std::begin(m), std::end(m),                                      \
+                               [&j](const std::pair<ENUM_TYPE, BasicJsonType>& ej_pair) -> bool \
+        {                                                                                       \
+            return ej_pair.second == j;                                                         \
+        });                                                                                     \
+        e = ((it != std::end(m)) ? it : std::begin(m))->first;                                  \
+    }
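+
+// Editor's note: a minimal usage sketch of the macro above (illustrative only,
+// not part of the upstream header; the enum is hypothetical). Unknown values
+// fall back to the first mapping entry in both directions:
+//
+//   enum class TaskState { invalid, waiting, running };
+//   NLOHMANN_JSON_SERIALIZE_ENUM(TaskState, {
+//       { TaskState::invalid, nullptr },
+//       { TaskState::waiting, "waiting" },
+//       { TaskState::running, "running" },
+//   })
+//   // nlohmann::json(TaskState::running) serializes to "running"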
+
+// Ugly macros to avoid uglier copy-paste when specializing basic_json. They
+// may be removed in the future once the class is split.
+
+#define NLOHMANN_BASIC_JSON_TPL_DECLARATION                                \
+    template<template<typename, typename, typename...> class ObjectType,   \
+             template<typename, typename...> class ArrayType,              \
+             class StringType, class BooleanType, class NumberIntegerType, \
+             class NumberUnsignedType, class NumberFloatType,              \
+             template<typename> class AllocatorType,                       \
+             template<typename, typename = void> class JSONSerializer,     \
+             class BinaryType,                                             \
+             class CustomBaseClass>
+
+#define NLOHMANN_BASIC_JSON_TPL                                            \
+    basic_json<ObjectType, ArrayType, StringType, BooleanType,             \
+    NumberIntegerType, NumberUnsignedType, NumberFloatType,                \
+    AllocatorType, JSONSerializer, BinaryType, CustomBaseClass>
+
+// Macros to simplify conversion from/to types
+
+#define NLOHMANN_JSON_EXPAND( x ) x
+#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME
+#define NLOHMANN_JSON_PASTE(...) NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \
+        NLOHMANN_JSON_PASTE64, \
+        NLOHMANN_JSON_PASTE63, \
+        NLOHMANN_JSON_PASTE62, \
+        NLOHMANN_JSON_PASTE61, \
+        NLOHMANN_JSON_PASTE60, \
+        NLOHMANN_JSON_PASTE59, \
+        NLOHMANN_JSON_PASTE58, \
+        NLOHMANN_JSON_PASTE57, \
+        NLOHMANN_JSON_PASTE56, \
+        NLOHMANN_JSON_PASTE55, \
+        NLOHMANN_JSON_PASTE54, \
+        NLOHMANN_JSON_PASTE53, \
+        NLOHMANN_JSON_PASTE52, \
+        NLOHMANN_JSON_PASTE51, \
+        NLOHMANN_JSON_PASTE50, \
+        NLOHMANN_JSON_PASTE49, \
+        NLOHMANN_JSON_PASTE48, \
+        NLOHMANN_JSON_PASTE47, \
+        NLOHMANN_JSON_PASTE46, \
+        NLOHMANN_JSON_PASTE45, \
+        NLOHMANN_JSON_PASTE44, \
+        NLOHMANN_JSON_PASTE43, \
+        NLOHMANN_JSON_PASTE42, \
+        NLOHMANN_JSON_PASTE41, \
+        NLOHMANN_JSON_PASTE40, \
+        NLOHMANN_JSON_PASTE39, \
+        NLOHMANN_JSON_PASTE38, \
+        NLOHMANN_JSON_PASTE37, \
+        NLOHMANN_JSON_PASTE36, \
+        NLOHMANN_JSON_PASTE35, \
+        NLOHMANN_JSON_PASTE34, \
+        NLOHMANN_JSON_PASTE33, \
+        NLOHMANN_JSON_PASTE32, \
+        NLOHMANN_JSON_PASTE31, \
+        NLOHMANN_JSON_PASTE30, \
+        NLOHMANN_JSON_PASTE29, \
+        NLOHMANN_JSON_PASTE28, \
+        NLOHMANN_JSON_PASTE27, \
+        NLOHMANN_JSON_PASTE26, \
+        NLOHMANN_JSON_PASTE25, \
+        NLOHMANN_JSON_PASTE24, \
+        NLOHMANN_JSON_PASTE23, \
+        NLOHMANN_JSON_PASTE22, \
+        NLOHMANN_JSON_PASTE21, \
+        NLOHMANN_JSON_PASTE20, \
+        NLOHMANN_JSON_PASTE19, \
+        NLOHMANN_JSON_PASTE18, \
+        NLOHMANN_JSON_PASTE17, \
+        NLOHMANN_JSON_PASTE16, \
+        NLOHMANN_JSON_PASTE15, \
+        NLOHMANN_JSON_PASTE14, \
+        NLOHMANN_JSON_PASTE13, \
+        NLOHMANN_JSON_PASTE12, \
+        NLOHMANN_JSON_PASTE11, \
+        NLOHMANN_JSON_PASTE10, \
+        NLOHMANN_JSON_PASTE9, \
+        NLOHMANN_JSON_PASTE8, \
+        NLOHMANN_JSON_PASTE7, \
+        NLOHMANN_JSON_PASTE6, \
+        NLOHMANN_JSON_PASTE5, \
+        NLOHMANN_JSON_PASTE4, \
+        NLOHMANN_JSON_PASTE3, \
+        NLOHMANN_JSON_PASTE2, \
+        NLOHMANN_JSON_PASTE1)(__VA_ARGS__))
+#define NLOHMANN_JSON_PASTE2(func, v1) func(v1)
+#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2)
+#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3)
+#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4)
+#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5)
+#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6)
+#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7)
+#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8)
+#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9)
+#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10)
+#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)
+#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12)
+#define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13)
+#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14)
+#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15)
+#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16)
+#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17)
+#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18)
+#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19)
+#define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20)
+#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21)
+#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22)
+#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23)
+#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24)
+#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25)
+#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26)
+#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27)
+#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28)
+#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29)
+#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30)
+#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31)
+#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32)
+#define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33)
+#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34)
+#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35)
+#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36)
+#define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37)
+#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38)
+#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39)
+#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40)
+#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41)
+#define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42)
+#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43)
+#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44)
+#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45)
+#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46)
+#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47)
+#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48)
+#define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49)
+#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50)
+#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51)
+#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52)
+#define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53)
+#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54)
+#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55)
+#define NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56)
+#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57)
+#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58)
+#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59)
+#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60)
+#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61)
+#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62)
+#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63)
+
+#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1;
+#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1);
+#define NLOHMANN_JSON_FROM_WITH_DEFAULT(v1) nlohmann_json_t.v1 = nlohmann_json_j.value(#v1, nlohmann_json_default_obj.v1);
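+
+// Editor's note: illustrative expansion (not part of the upstream header):
+//   NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, name, age)
+// counts its arguments via NLOHMANN_JSON_GET_MACRO and expands to
+//   NLOHMANN_JSON_TO(name) NLOHMANN_JSON_TO(age)
+// i.e. one `nlohmann_json_j[#field] = nlohmann_json_t.field;` statement per
+// listed member.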
+
+/*!
+@brief macro
+@def NLOHMANN_DEFINE_TYPE_INTRUSIVE
+@since version 3.9.0
+*/
+#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...)  \
+    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \
+    friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) }
+
+#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Type, ...)  \
+    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \
+    friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) }
+
+#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(Type, ...)  \
+    friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) }
+
+/*!
+@brief macro
+@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE
+@since version 3.9.0
+*/
+#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...)  \
+    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \
+    inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) }
+
+#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE(Type, ...)  \
+    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) }
+
+#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, ...)  \
+    inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \
+    inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) }
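+
+// Editor's note: an illustrative sketch of the convenience macros above
+// (hypothetical struct, not part of the upstream header):
+//
+//   struct Person { std::string name; int age; };
+//   NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Person, name, age)
+//   // nlohmann::json j = Person{"Ada", 36};  -> {"age":36,"name":"Ada"}
+//   // auto p = j.get<Person>();              -> round-trips via from_json
+//
+// The *_WITH_DEFAULT variants fall back to a default-constructed object for
+// members missing from the JSON instead of throwing; the *_ONLY_SERIALIZE
+// variants generate to_json only.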
+
+// inspired from https://stackoverflow.com/a/26745591
+// allows to call any std function as if (e.g. with begin):
+// using std::begin; begin(x);
+//
+// it allows using the detected idiom to retrieve the return type
+// of such an expression
+#define NLOHMANN_CAN_CALL_STD_FUNC_IMPL(std_name)                                 \
+    namespace detail {                                                            \
+    using std::std_name;                                                          \
+    \
+    template<typename... T>                                                       \
+    using result_of_##std_name = decltype(std_name(std::declval<T>()...));        \
+    }                                                                             \
+    \
+    namespace detail2 {                                                           \
+    struct std_name##_tag                                                         \
+    {                                                                             \
+    };                                                                            \
+    \
+    template<typename... T>                                                       \
+    std_name##_tag std_name(T&&...);                                              \
+    \
+    template<typename... T>                                                       \
+    using result_of_##std_name = decltype(std_name(std::declval<T>()...));        \
+    \
+    template<typename... T>                                                       \
+    struct would_call_std_##std_name                                              \
+    {                                                                             \
+        static constexpr auto const value = ::nlohmann::detail::                  \
+                                            is_detected_exact<std_name##_tag, result_of_##std_name, T...>::value; \
+    };                                                                            \
+    } /* namespace detail2 */ \
+    \
+    template<typename... T>                                                       \
+    struct would_call_std_##std_name : detail2::would_call_std_##std_name<T...>   \
+    {                                                                             \
+    }
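+
+// Editor's note (my reading of the macro above, illustrative only):
+// NLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin), used further below, produces a trait
+// along the lines of nlohmann::would_call_std_begin<T...>::value, which is
+// true roughly when an unqualified call `begin(std::declval<T>()...)` (with
+// `using std::begin;` in scope) would fall through to std::begin rather than
+// an ADL-found overload.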
+
+#ifndef JSON_USE_IMPLICIT_CONVERSIONS
+    #define JSON_USE_IMPLICIT_CONVERSIONS 1
+#endif
+
+#if JSON_USE_IMPLICIT_CONVERSIONS
+    #define JSON_EXPLICIT
+#else
+    #define JSON_EXPLICIT explicit
+#endif
+
+#ifndef JSON_DISABLE_ENUM_SERIALIZATION
+    #define JSON_DISABLE_ENUM_SERIALIZATION 0
+#endif
+
+#ifndef JSON_USE_GLOBAL_UDLS
+    #define JSON_USE_GLOBAL_UDLS 1
+#endif
+
+#if JSON_HAS_THREE_WAY_COMPARISON
+    #include <compare> // partial_ordering
+#endif
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+///////////////////////////
+// JSON type enumeration //
+///////////////////////////
+
+/*!
+@brief the JSON type enumeration
+
+This enumeration collects the different JSON types. It is internally used to
+distinguish the stored values, and the functions @ref basic_json::is_null(),
+@ref basic_json::is_object(), @ref basic_json::is_array(),
+@ref basic_json::is_string(), @ref basic_json::is_boolean(),
+@ref basic_json::is_number() (with @ref basic_json::is_number_integer(),
+@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()),
+@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and
+@ref basic_json::is_structured() rely on it.
+
+@note There are three enumeration entries (number_integer, number_unsigned, and
+number_float), because the library distinguishes these three types for numbers:
+@ref basic_json::number_unsigned_t is used for unsigned integers,
+@ref basic_json::number_integer_t is used for signed integers, and
+@ref basic_json::number_float_t is used for floating-point numbers or to
+approximate integers which do not fit in the limits of their respective type.
+
+@sa see @ref basic_json::basic_json(const value_t value_type) -- create a JSON
+value with the default value for a given type
+
+@since version 1.0.0
+*/
+enum class value_t : std::uint8_t
+{
+    null,             ///< null value
+    object,           ///< object (unordered set of name/value pairs)
+    array,            ///< array (ordered collection of values)
+    string,           ///< string value
+    boolean,          ///< boolean value
+    number_integer,   ///< number value (signed integer)
+    number_unsigned,  ///< number value (unsigned integer)
+    number_float,     ///< number value (floating-point)
+    binary,           ///< binary array (ordered collection of bytes)
+    discarded         ///< discarded by the parser callback function
+};
+
+/*!
+@brief comparison operator for JSON types
+
+Returns an ordering that is similar to Python:
+- order: null < boolean < number < object < array < string < binary
+- furthermore, each type is not smaller than itself
+- discarded values are not comparable
+- binary is represented as a b"" string in python and directly comparable to a
+  string; however, making a binary array directly comparable with a string would
+  be surprising behavior in a JSON file.
+
+@since version 1.0.0
+*/
+#if JSON_HAS_THREE_WAY_COMPARISON
+    inline std::partial_ordering operator<=>(const value_t lhs, const value_t rhs) noexcept // *NOPAD*
+#else
+    inline bool operator<(const value_t lhs, const value_t rhs) noexcept
+#endif
+{
+    static constexpr std::array<std::uint8_t, 9> order = {{
+            0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */,
+            1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */,
+            6 /* binary */
+        }
+    };
+
+    const auto l_index = static_cast<std::size_t>(lhs);
+    const auto r_index = static_cast<std::size_t>(rhs);
+#if JSON_HAS_THREE_WAY_COMPARISON
+    if (l_index < order.size() && r_index < order.size())
+    {
+        return order[l_index] <=> order[r_index]; // *NOPAD*
+    }
+    return std::partial_ordering::unordered;
+#else
+    return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index];
+#endif
+}
+
+// GCC selects the built-in operator< over an operator rewritten from
+// a user-defined spaceship operator
+// Clang, MSVC, and ICC select the rewritten candidate
+// (see GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105200)
+#if JSON_HAS_THREE_WAY_COMPARISON && defined(__GNUC__)
+inline bool operator<(const value_t lhs, const value_t rhs) noexcept
+{
+    return std::is_lt(lhs <=> rhs); // *NOPAD*
+}
+#endif
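+
+// Editor's note: illustrative consequence of the order table above:
+// null < boolean < number (integer/unsigned/float share one rank) < object
+// < array < string < binary, while any comparison involving
+// value_t::discarded is unordered (operator< yields false).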
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+// #include 
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+/*!
+@brief replace all occurrences of a substring by another string
+
+@param[in,out] s  the string to manipulate; changed so that all
+               occurrences of @a f are replaced with @a t
+@param[in]     f  the substring to replace with @a t
+@param[in]     t  the string to replace @a f
+
+@pre The search string @a f must not be empty. **This precondition is
+enforced with an assertion.**
+
+@since version 2.0.0
+*/
+template<typename StringType>
+inline void replace_substring(StringType& s, const StringType& f,
+                              const StringType& t)
+{
+    JSON_ASSERT(!f.empty());
+    for (auto pos = s.find(f);                // find first occurrence of f
+            pos != StringType::npos;          // make sure f was found
+            s.replace(pos, f.size(), t),      // replace with t, and
+            pos = s.find(f, pos + t.size()))  // find next occurrence of f
+    {}
+}
+
+/*!
+ * @brief string escaping as described in RFC 6901 (Sect. 4)
+ * @param[in] s string to escape
+ * @return    escaped string
+ *
+ * Note the order of escaping "~" to "~0" and "/" to "~1" is important.
+ */
+template<typename StringType>
+inline StringType escape(StringType s)
+{
+    replace_substring(s, StringType{"~"}, StringType{"~0"});
+    replace_substring(s, StringType{"/"}, StringType{"~1"});
+    return s;
+}
+
+/*!
+ * @brief string unescaping as described in RFC 6901 (Sect. 4)
+ * @param[in] s string to unescape
+ * @return    unescaped string
+ *
+ * Note the order of escaping "~1" to "/" and "~0" to "~" is important.
+ */
+template<typename StringType>
+static void unescape(StringType& s)
+{
+    replace_substring(s, StringType{"~1"}, StringType{"/"});
+    replace_substring(s, StringType{"~0"}, StringType{"~"});
+}
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <cstddef> // size_t
+
+// #include 
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+/// struct to capture the start position of the current token
+struct position_t
+{
+    /// the total number of characters read
+    std::size_t chars_read_total = 0;
+    /// the number of characters read in the current line
+    std::size_t chars_read_current_line = 0;
+    /// the number of lines read
+    std::size_t lines_read = 0;
+
+    /// conversion to size_t to preserve SAX interface
+    constexpr operator size_t() const
+    {
+        return chars_read_total;
+    }
+};
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-FileCopyrightText: 2018 The Abseil Authors
+// SPDX-License-Identifier: MIT
+
+
+
+#include <array> // array
+#include <cstddef> // size_t
+#include <type_traits> // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type
+#include <utility> // index_sequence, make_index_sequence, index_sequence_for
+
+// #include 
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+template<typename T>
+using uncvref_t = typename std::remove_cv<typename std::remove_reference<T>::type>::type;
+
+#ifdef JSON_HAS_CPP_14
+
+// the following utilities are natively available in C++14
+using std::enable_if_t;
+using std::index_sequence;
+using std::make_index_sequence;
+using std::index_sequence_for;
+
+#else
+
+// alias templates to reduce boilerplate
+template<bool B, typename T = void>
+using enable_if_t = typename std::enable_if<B, T>::type;
+
+// The following code is taken from https://github.com/abseil/abseil-cpp/blob/10cb35e459f5ecca5b2ff107635da0bfa41011b4/absl/utility/utility.h
+// which is part of Google Abseil (https://github.com/abseil/abseil-cpp), licensed under the Apache License 2.0.
+
+//// START OF CODE FROM GOOGLE ABSEIL
+
+// integer_sequence
+//
+// Class template representing a compile-time integer sequence. An instantiation
+// of `integer_sequence<T, Ints...>` has a sequence of integers encoded in its
+// type through its template arguments (which is a common need when
+// working with C++11 variadic templates). `absl::integer_sequence` is designed
+// to be a drop-in replacement for C++14's `std::integer_sequence`.
+//
+// Example:
+//
+//   template< class T, T... Ints >
+//   void user_function(integer_sequence<T, Ints...>);
+//
+//   int main()
+//   {
+//     // user_function's `T` will be deduced to `int` and `Ints...`
+//     // will be deduced to `0, 1, 2, 3, 4`.
+//     user_function(make_integer_sequence<int, 5>());
+//   }
+template <typename T, T... Ints>
+struct integer_sequence
+{
+    using value_type = T;
+    static constexpr std::size_t size() noexcept
+    {
+        return sizeof...(Ints);
+    }
+};
+
+// index_sequence
+//
+// A helper template for an `integer_sequence` of `size_t`,
+// `absl::index_sequence` is designed to be a drop-in replacement for C++14's
+// `std::index_sequence`.
+template <size_t... Ints>
+using index_sequence = integer_sequence<size_t, Ints...>;
+
+namespace utility_internal
+{
+
+template <typename Seq, size_t SeqSize, size_t Rem>
+struct Extend;
+
+// Note that SeqSize == sizeof...(Ints). It's passed explicitly for efficiency.
+template <typename T, T... Ints, size_t SeqSize>
+struct Extend<integer_sequence<T, Ints...>, SeqSize, 0>
+{
+    using type = integer_sequence < T, Ints..., (Ints + SeqSize)... >;
+};
+
+template <typename T, T... Ints, size_t SeqSize>
+struct Extend<integer_sequence<T, Ints...>, SeqSize, 1>
+{
+    using type = integer_sequence < T, Ints..., (Ints + SeqSize)..., 2 * SeqSize >;
+};
+
+// Recursion helper for 'make_integer_sequence<T, N>'.
+// 'Gen<T, N>::type' is an alias for 'integer_sequence<T, 0, 1, ... N-1>'.
+template <typename T, size_t N>
+struct Gen
+{
+    using type =
+        typename Extend < typename Gen < T, N / 2 >::type, N / 2, N % 2 >::type;
+};
+
+template <typename T>
+struct Gen<T, 0>
+{
+    using type = integer_sequence<T>;
+};
+
+}  // namespace utility_internal
+
+// Compile-time sequences of integers
+
+// make_integer_sequence
+//
+// This template alias is equivalent to
+// `integer_sequence<int, 0, 1, ..., N-1>`, and is designed to be a drop-in
+// replacement for C++14's `std::make_integer_sequence`.
+template <typename T, T N>
+using make_integer_sequence = typename utility_internal::Gen<T, N>::type;
+
+// make_index_sequence
+//
+// This template alias is equivalent to `index_sequence<0, 1, ..., N-1>`,
+// and is designed to be a drop-in replacement for C++14's
+// `std::make_index_sequence`.
+template <size_t N>
+using make_index_sequence = make_integer_sequence<size_t, N>;
+
+// index_sequence_for
+//
+// Converts a typename pack into an index sequence of the same length, and
+// is designed to be a drop-in replacement for C++14's
+// `std::index_sequence_for()`
+template <typename... Ts>
+using index_sequence_for = make_index_sequence<sizeof...(Ts)>;
+
+//// END OF CODE FROM GOOGLE ABSEIL
+
+#endif
+
+// dispatch utility (taken from ranges-v3)
+template<unsigned N> struct priority_tag : priority_tag < N - 1 > {};
+template<> struct priority_tag<0> {};
+
+// taken from ranges-v3
+template<typename T>
+struct static_const
+{
+    static JSON_INLINE_VARIABLE constexpr T value{};
+};
+
+#ifndef JSON_HAS_CPP_17
+    template<typename T>
+    constexpr T static_const<T>::value;
+#endif
+
+template<typename T, typename... Args>
+inline constexpr std::array<T, sizeof...(Args)> make_array(Args&& ... args)
+{
+    return std::array<T, sizeof...(Args)> {{static_cast<T>(std::forward<Args>(args))...}};
+}
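+
+// Editor's note: illustrative use of make_array (hypothetical call sites):
+//   auto a = make_array<int>(1, 2, 3);     // std::array<int, 3>{{1, 2, 3}}
+//   auto b = make_array<std::size_t>(7u);  // std::array<std::size_t, 1>{{7}}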
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <limits> // numeric_limits
+#include <type_traits> // false_type, is_constructible, is_integral, is_same, true_type
+#include <utility> // declval
+#include <tuple> // tuple
+#include <string> // char_traits
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+#include <iterator> // random_access_iterator_tag
+
+// #include 
+
+// #include 
+
+// #include 
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+namespace detail
+{
+
+template<typename It, typename = void>
+struct iterator_types {};
+
+template<typename It>
+struct iterator_types <
+    It,
+    void_t<typename It::difference_type, typename It::value_type, typename It::pointer,
+    typename It::reference, typename It::iterator_category >>
+{
+    using difference_type = typename It::difference_type;
+    using value_type = typename It::value_type;
+    using pointer = typename It::pointer;
+    using reference = typename It::reference;
+    using iterator_category = typename It::iterator_category;
+};
+
+// This is required as some compilers implement std::iterator_traits in a way that
+// doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341.
+template<typename T, typename = void>
+struct iterator_traits
+{
+};
+
+template<typename T>
+struct iterator_traits < T, enable_if_t < !std::is_pointer<T>::value >>
+            : iterator_types<T>
+{
+};
+
+template<typename T>
+struct iterator_traits<T*, enable_if_t<std::is_object<T>::value>>
+{
+    using iterator_category = std::random_access_iterator_tag;
+    using value_type = T;
+    using difference_type = ptrdiff_t;
+    using pointer = T*;
+    using reference = T&;
+};
+
+}  // namespace detail
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+// #include 
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+
+NLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin);
+
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+
+
+// #include 
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+
+NLOHMANN_CAN_CALL_STD_FUNC_IMPL(end);
+
+NLOHMANN_JSON_NAMESPACE_END
+
+// #include 
+
+// #include 
+
+// #include 
+//     __ _____ _____ _____
+//  __|  |   __|     |   | |  JSON for Modern C++
+// |  |  |__   |  |  | | | |  version 3.11.3
+// |_____|_____|_____|_|___|  https://github.com/nlohmann/json
+//
+// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann 
+// SPDX-License-Identifier: MIT
+
+#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_
+    #define INCLUDE_NLOHMANN_JSON_FWD_HPP_
+
+    #include <cstdint> // int64_t, uint64_t
+    #include <map> // map
+    #include <memory> // allocator
+    #include <string> // string
+    #include <vector> // vector
+
+    // #include 
+
+
+    /*!
+    @brief namespace for Niels Lohmann
+    @see https://github.com/nlohmann
+    @since version 1.0.0
+    */
+    NLOHMANN_JSON_NAMESPACE_BEGIN
+
+    /*!
+    @brief default JSONSerializer template argument
+
+    This serializer ignores the template arguments and uses ADL
+    ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl))
+    for serialization.
+    */
+    template<typename T = void, typename SFINAE = void>
+    struct adl_serializer;
+
+    /// a class to store JSON values
+    /// @sa https://json.nlohmann.me/api/basic_json/
+    template<template<typename U, typename V, typename... Args> class ObjectType =
+    std::map,
+    template<typename U, typename... Args> class ArrayType = std::vector,
+    class StringType = std::string, class BooleanType = bool,
+    class NumberIntegerType = std::int64_t,
+    class NumberUnsignedType = std::uint64_t,
+    class NumberFloatType = double,
+    template<typename U> class AllocatorType = std::allocator,
+    template<typename T, typename SFINAE = void> class JSONSerializer =
+    adl_serializer,
+    class BinaryType = std::vector<std::uint8_t>, // cppcheck-suppress syntaxError
+    class CustomBaseClass = void>
+    class basic_json;
+
+    /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document
+    /// @sa https://json.nlohmann.me/api/json_pointer/
+    template<typename RefStringType>
+    class json_pointer;
+
+    /*!
+    @brief default specialization
+    @sa https://json.nlohmann.me/api/json/
+    */
+    using json = basic_json<>;
+
+    /// @brief a minimal map-like container that preserves insertion order
+    /// @sa https://json.nlohmann.me/api/ordered_map/
+    template<class Key, class T, class IgnoredLess = std::less<Key>,
+             class Allocator = std::allocator<std::pair<const Key, T>>>
+    struct ordered_map;
+
+    /// @brief specialization that maintains the insertion order of object keys
+    /// @sa https://json.nlohmann.me/api/ordered_json/
+    using ordered_json = basic_json<nlohmann::ordered_map>;
+
+    NLOHMANN_JSON_NAMESPACE_END
+
+#endif  // INCLUDE_NLOHMANN_JSON_FWD_HPP_
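+
+// Editor's note: illustrative distinction between the two aliases declared
+// above (hypothetical snippet, not part of the upstream header):
+// nlohmann::json keeps object keys sorted via std::map, while
+// nlohmann::ordered_json preserves insertion order:
+//   nlohmann::json j;          j["b"] = 1; j["a"] = 2;   // dumps {"a":2,"b":1}
+//   nlohmann::ordered_json oj; oj["b"] = 1; oj["a"] = 2; // dumps {"b":1,"a":2}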
+
+
+NLOHMANN_JSON_NAMESPACE_BEGIN
+/*!
+@brief detail namespace with internal helper functions
+
+This namespace collects functions that should not be exposed,
+implementations of some @ref basic_json methods, and meta-programming helpers.
+
+@since version 2.1.0
+*/
+namespace detail
+{
+
+/////////////
+// helpers //
+/////////////
+
+// Note to maintainers:
+//
+// Every trait in this file expects a non CV-qualified type.
+// The only exceptions are in the 'aliases for detected' section
+// (i.e. those of the form: decltype(T::member_function(std::declval<T>())))
+//
+// In this case, T has to be properly CV-qualified to constraint the function arguments
+// (e.g. to_json(BasicJsonType&, const T&))
+
+template<typename> struct is_basic_json : std::false_type {};
+
+NLOHMANN_BASIC_JSON_TPL_DECLARATION
+struct is_basic_json<NLOHMANN_BASIC_JSON_TPL> : std::true_type {};
+
+// used by exceptions create() member functions
+// true_type for pointer to possibly cv-qualified basic_json or std::nullptr_t
+// false_type otherwise
+template<typename BasicJsonContext>
+struct is_basic_json_context :
+    std::integral_constant < bool,
+    is_basic_json<typename std::remove_cv<typename std::remove_pointer<BasicJsonContext>::type>::type>::value
+    || std::is_same<BasicJsonContext, std::nullptr_t>::value >
+{};
+
+//////////////////////
+// json_ref helpers //
+//////////////////////
+
+template<typename>
+class json_ref;
+
+template<typename>
+struct is_json_ref : std::false_type {};
+
+template<typename T>
+struct is_json_ref<json_ref<T>> : std::true_type {};
+
+//////////////////////////
+// aliases for detected //
+//////////////////////////
+
+template<typename T>
+using mapped_type_t = typename T::mapped_type;
+
+template<typename T>
+using key_type_t = typename T::key_type;
+
+template<typename T>
+using value_type_t = typename T::value_type;
+
+template<typename T>
+using difference_type_t = typename T::difference_type;
+
+template<typename T>
+using pointer_t = typename T::pointer;
+
+template<typename T>
+using reference_t = typename T::reference;
+
+template<typename T>
+using iterator_category_t = typename T::iterator_category;
+
+template<typename T, typename... Args>
+using to_json_function = decltype(T::to_json(std::declval<Args>()...));
+
+template<typename T, typename... Args>
+using from_json_function = decltype(T::from_json(std::declval<Args>()...));
+
+template<typename T, typename U>
+using get_template_function = decltype(std::declval<T>().template get<U>());
+
+// trait checking if JSONSerializer<T>::from_json(json const&, udt&) exists
+template<typename BasicJsonType, typename T, typename = void>
+struct has_from_json : std::false_type {};
+
+// trait checking if j.get<T> is valid
+// use this trait instead of std::is_constructible or std::is_convertible,
+// both rely on, or make use of implicit conversions, and thus fail when T
+// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958)
+template <typename BasicJsonType, typename T>
+struct is_getable
+{
+    static constexpr bool value = is_detected<get_template_function, const BasicJsonType&, T>::value;
+};
+
+template<typename BasicJsonType, typename T>
+struct has_from_json < BasicJsonType, T, enable_if_t < !is_basic_json<T>::value >>
+{
+    using serializer = typename BasicJsonType::template json_serializer<T, void>;
+
+    static constexpr bool value =
+        is_detected_exact<void, from_json_function, serializer,
+        const BasicJsonType&, T&>::value;
+};
+
+// This trait checks if JSONSerializer<T>::from_json(json const&) exists
+// this overload is used for non-default-constructible user-defined-types
+template<typename BasicJsonType, typename T, typename = void>
+struct has_non_default_from_json : std::false_type {};
+
+template<typename BasicJsonType, typename T>
+struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json<T>::value >>
+{
+    using serializer = typename BasicJsonType::template json_serializer<T, void>;
+
+    static constexpr bool value =
+        is_detected_exact<T, from_json_function, serializer,
+        const BasicJsonType&>::value;
+};
+
+// This trait checks if BasicJsonType::json_serializer<T>::to_json exists
+// Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion.
+template<typename BasicJsonType, typename T, typename = void>
+struct has_to_json : std::false_type {};
+
+template<typename BasicJsonType, typename T>
+struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json<T>::value >>
+{
+    using serializer = typename BasicJsonType::template json_serializer<T, void>;
+
+    static constexpr bool value =
+        is_detected_exact<void, to_json_function, serializer, BasicJsonType&,
+        T>::value;
+};
+
+template<typename T>
+using detect_key_compare = typename T::key_compare;
+
+template<typename T>
+struct has_key_compare : std::integral_constant<bool, is_detected<detect_key_compare, T>::value> {};
+
+// obtains the actual object key comparator
+template<typename BasicJsonType>
+struct actual_object_comparator
+{
+    using object_t = typename BasicJsonType::object_t;
+    using object_comparator_t = typename BasicJsonType::default_object_comparator_t;
+    using type = typename std::conditional < has_key_compare<object_t>::value,
+          typename object_t::key_compare, object_comparator_t>::type;
+};
+
+template<typename BasicJsonType>
+using actual_object_comparator_t = typename actual_object_comparator<BasicJsonType>::type;
+
+/////////////////
+// char_traits //
+/////////////////
+
+// Primary template of char_traits calls std char_traits
+template<typename T>
+struct char_traits : std::char_traits<T>
+{};
+
+// Explicitly define char traits for unsigned char since it is not standard
+template<>
+struct char_traits<unsigned char> : std::char_traits<char>
+{
+    using char_type = unsigned char;
+    using int_type = uint64_t;
+
+    // Redefine to_int_type function
+    static int_type to_int_type(char_type c) noexcept
+    {
+        return static_cast(c);
+    }
+
+    static char_type to_char_type(int_type i) noexcept
+    {
+        return static_cast(i);
+    }
+
+    static constexpr int_type eof() noexcept
+    {
+        return static_cast(EOF);
+    }
+};
+
+// Explicitly define char traits for signed char since it is not standard
+template<>
+struct char_traits : std::char_traits
+{
+    using char_type = signed char;
+    using int_type = uint64_t;
+
+    // Redefine to_int_type function
+    static int_type to_int_type(char_type c) noexcept
+    {
+        return static_cast(c);
+    }
+
+    static char_type to_char_type(int_type i) noexcept
+    {
+        return static_cast(i);
+    }
+
+    static constexpr int_type eof() noexcept
+    {
+        return static_cast(EOF);
+    }
+};
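+
+// Illustrative note (added for review, not part of the upstream nlohmann/json
+// header): the standard library provides no std::char_traits specialization for
+// signed char or unsigned char, so these definitions are what let byte-oriented
+// input be handed to the parsers, e.g. (hypothetical payload):
+//     std::vector<std::uint8_t> payload = /* CBOR bytes */;
+//     auto j = nlohmann::json::from_cbor(payload);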
+
+///////////////////
+// is_ functions //
+///////////////////
+
+// https://en.cppreference.com/w/cpp/types/conjunction
+template<class...> struct conjunction : std::true_type { };
+template<class B> struct conjunction<B> : B { };
+template<class B, class... Bn>
+struct conjunction<B, Bn...>
+: std::conditional<static_cast<bool>(B::value), conjunction<Bn...>, B>::type {};
+
+// https://en.cppreference.com/w/cpp/types/negation
+template<class B> struct negation : std::integral_constant < bool, !B::value > { };
+
+// Reimplementation of is_constructible and is_default_constructible, due to them being broken for
+// std::pair and std::tuple until LWG 2367 fix (see https://cplusplus.github.io/LWG/lwg-defects.html#2367).
+// This causes compile errors in e.g. clang 3.5 or gcc 4.9.
+template <typename T>
+struct is_default_constructible : std::is_default_constructible<T> {};
+
+template <typename T1, typename T2>
+struct is_default_constructible<std::pair<T1, T2>>
+            : conjunction<is_default_constructible<T1>, is_default_constructible<T2>> {};
+
+template <typename T1, typename T2>
+struct is_default_constructible<const std::pair<T1, T2>>
+            : conjunction<is_default_constructible<T1>, is_default_constructible<T2>> {};
+
+template <typename... Ts>
+struct is_default_constructible<std::tuple<Ts...>>
+            : conjunction<is_default_constructible<Ts>...> {};
+
+template <typename... Ts>
+struct is_default_constructible<const std::tuple<Ts...>>
+            : conjunction<is_default_constructible<Ts>...> {};
+
+template <typename T, typename... Args>
+struct is_constructible : std::is_constructible<T, Args...> {};
+
+template <typename T1, typename T2>
+struct is_constructible<std::pair<T1, T2>> : is_default_constructible<std::pair<T1, T2>> {};
+
+template <typename T1, typename T2>
+struct is_constructible<const std::pair<T1, T2>> : is_default_constructible<const std::pair<T1, T2>> {};
+
+template <typename... Ts>
+struct is_constructible<std::tuple<Ts...>> : is_default_constructible<std::tuple<Ts...>> {};
+
+template <typename... Ts>
+struct is_constructible<const std::tuple<Ts...>> : is_default_constructible<const std::tuple<Ts...>> {};
+
+template<typename T, typename = void>
+struct is_iterator_traits : std::false_type {};
+
+template<typename T>
+struct is_iterator_traits<iterator_traits<T>>
+{
+  private:
+    using traits = iterator_traits<T>;
+
+  public:
+    static constexpr auto value =
+        is_detected<value_type_t, traits>::value &&
+        is_detected<difference_type_t, traits>::value &&
+        is_detected<pointer_t, traits>::value &&
+        is_detected<iterator_category_t, traits>::value &&
+        is_detected<reference_t, traits>::value;
+};
+
+template<typename T>
+struct is_range
+{
+  private:
+    using t_ref = typename std::add_lvalue_reference<T>::type;
+
+    using iterator = detected_t<result_of_begin, t_ref>;
+    using sentinel = detected_t<result_of_end, t_ref>;
+
+    // to be 100% correct, it should use https://en.cppreference.com/w/cpp/iterator/input_or_output_iterator
+    // and https://en.cppreference.com/w/cpp/iterator/sentinel_for
+    // but reimplementing these would be too much work, as a lot of other concepts are used underneath
+    static constexpr auto is_iterator_begin =
+        is_iterator_traits<iterator_traits<iterator>>::value;
+
+  public:
+    static constexpr bool value = !std::is_same<iterator, nonesuch>::value && !std::is_same<sentinel, nonesuch>::value && is_iterator_begin;
+};
+
+template<typename R>
+using iterator_t = enable_if_t<is_range<R>::value, result_of_begin<decltype(std::declval<R&>())>>;
+
+template<typename T>
+using range_value_t = value_type_t<iterator_traits<iterator_t<T>>>;
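+
+// Illustrative note (added for review, not part of the upstream nlohmann/json
+// header): is_range accepts any type whose begin()/end() yield a usable iterator,
+// so is_range<std::vector<int>>::value is true and range_value_t<std::vector<int>>
+// names int, while a scalar such as double fails the detection.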
+
+// The following implementation of is_complete_type is taken from
+// https://blogs.msdn.microsoft.com/vcblog/2015/12/02/partial-support-for-expression-sfinae-in-vs-2015-update-1/
+// and is written by Xiang Fan who agreed to using it in this library.
+
+template <typename T, typename = void>
+struct is_complete_type : std::false_type {};
+
+template <typename T>
+struct is_complete_type<T, decltype(void(sizeof(T)))> : std::true_type {};
+
+template <typename BasicJsonType, typename CompatibleObjectType,
+          typename = void>
+struct is_compatible_object_type_impl : std::false_type {};
+
+template <typename BasicJsonType, typename CompatibleObjectType>
+struct is_compatible_object_type_impl <
+    BasicJsonType, CompatibleObjectType,
+    enable_if_t < is_detected<mapped_type_t, CompatibleObjectType>::value&&
+    is_detected<key_type_t, CompatibleObjectType>::value >>
+{
+    using object_t = typename BasicJsonType::object_t;
+
+    // macOS's is_constructible does not play well with nonesuch...
+    static constexpr bool value =
+        is_constructible<typename object_t::key_type,
+        typename CompatibleObjectType::key_type>::value &&
+        is_constructible<typename object_t::mapped_type,
+        typename CompatibleObjectType::mapped_type>::value;
+};
+
+template <typename BasicJsonType, typename CompatibleObjectType>
+struct is_compatible_object_type
+    : is_compatible_object_type_impl<BasicJsonType, CompatibleObjectType> {};
+
+template <typename BasicJsonType, typename ConstructibleObjectType,
+          typename = void>
+struct is_constructible_object_type_impl : std::false_type {};
+
+template <typename BasicJsonType, typename ConstructibleObjectType>
+struct is_constructible_object_type_impl <
+    BasicJsonType, ConstructibleObjectType,
+    enable_if_t < is_detected<mapped_type_t, ConstructibleObjectType>::value&&
+    is_detected<key_type_t, ConstructibleObjectType>::value >>
+{
+    using object_t = typename BasicJsonType::object_t;
+
+    static constexpr bool value =
+        (is_default_constructible<ConstructibleObjectType>::value &&
+         (std::is_move_assignable<ConstructibleObjectType>::value ||
+          std::is_copy_assignable<ConstructibleObjectType>::value) &&
+         (is_constructible<typename ConstructibleObjectType::key_type,
+          typename object_t::key_type>::value &&
+          std::is_same <
+          typename object_t::mapped_type,
+          typename ConstructibleObjectType::mapped_type >::value)) ||
+        (has_from_json<BasicJsonType,
+         typename ConstructibleObjectType::mapped_type>::value ||
+         has_non_default_from_json <
+         BasicJsonType,
+         typename ConstructibleObjectType::mapped_type >::value);
+};
+
+template <typename BasicJsonType, typename ConstructibleObjectType>
+struct is_constructible_object_type
+    : is_constructible_object_type_impl<BasicJsonType,
+      ConstructibleObjectType> {};
+
+template <typename BasicJsonType, typename CompatibleStringType>
+struct is_compatible_string_type
+{
+    static constexpr auto value =
+        is_constructible<typename BasicJsonType::string_t, CompatibleStringType>::value;
+};
+
+template <typename BasicJsonType, typename ConstructibleStringType>
+struct is_constructible_string_type
+{
+    // launder type through decltype() to fix compilation failure on ICPC
+#ifdef __INTEL_COMPILER
+    using laundered_type = decltype(std::declval<ConstructibleStringType>());
+#else
+    using laundered_type = ConstructibleStringType;
+#endif
+
+    static constexpr auto value =
+        conjunction <
+        is_constructible<laundered_type, typename BasicJsonType::string_t>,
+        is_detected_exact<typename BasicJsonType::string_t::value_type,
+        value_type_t, laundered_type >>::value;
+};
+
+template <typename BasicJsonType, typename CompatibleArrayType, typename = void>
+struct is_compatible_array_type_impl : std::false_type {};
+
+template <typename BasicJsonType, typename CompatibleArrayType>
+struct is_compatible_array_type_impl <
+    BasicJsonType, CompatibleArrayType,
+    enable_if_t <
+    is_detected<iterator_t, CompatibleArrayType>::value&&
+    is_iterator_traits<iterator_traits<detected_t<iterator_t, CompatibleArrayType>>>::value&&
+// special case for types like std::filesystem::path whose iterator's value_type are themselves
+// c.f. https://github.com/nlohmann/json/pull/3073
+    !std::is_same<CompatibleArrayType, detected_t<range_value_t, CompatibleArrayType>>::value >>
+{
+    static constexpr bool value =
+        is_constructible<BasicJsonType,
+        range_value_t<CompatibleArrayType>>::value;
+};
+
+template <typename BasicJsonType, typename CompatibleArrayType>
+struct is_compatible_array_type
+    : is_compatible_array_type_impl<BasicJsonType, CompatibleArrayType> {};
+
+template <typename BasicJsonType, typename ConstructibleArrayType, typename = void>
+struct is_constructible_array_type_impl : std::false_type {};
+
+template <typename BasicJsonType, typename ConstructibleArrayType>
+struct is_constructible_array_type_impl <
+    BasicJsonType, ConstructibleArrayType,
+    enable_if_t<std::is_same<ConstructibleArrayType,
+    typename BasicJsonType::value_type>::value >>
+            : std::true_type {};
+
+template <typename BasicJsonType, typename ConstructibleArrayType>
+struct is_constructible_array_type_impl <
+    BasicJsonType, ConstructibleArrayType,
+    enable_if_t < !std::is_same<ConstructibleArrayType,
+    typename BasicJsonType::value_type>::value&&
+    !is_compatible_string_type<BasicJsonType, ConstructibleArrayType>::value&&
+    is_default_constructible<ConstructibleArrayType>::value&&
+(std::is_move_assignable<ConstructibleArrayType>::value ||
+ std::is_copy_assignable<ConstructibleArrayType>::value)&&
+is_detected<iterator_t, ConstructibleArrayType>::value&&
+is_iterator_traits<iterator_traits<detected_t<iterator_t, ConstructibleArrayType>>>::value&&
+is_detected<range_value_t, ConstructibleArrayType>::value&&
+// special case for types like std::filesystem::path whose iterator's value_type are themselves
+// c.f. https://github.com/nlohmann/json/pull/3073
+!std::is_same<ConstructibleArrayType, detected_t<range_value_t, ConstructibleArrayType>>::value&&
+        is_complete_type <
+        detected_t<range_value_t, ConstructibleArrayType >>::value >>
+{
+    using value_type = range_value_t<ConstructibleArrayType>;
+
+    static constexpr bool value =
+        std::is_same<value_type,
+        typename BasicJsonType::array_t::value_type>::value ||
+        has_from_json<BasicJsonType,
+        value_type>::value ||
+        has_non_default_from_json <
+        BasicJsonType,
+        value_type >::value;
+};
+
+template <typename BasicJsonType, typename ConstructibleArrayType>
+struct is_constructible_array_type
+    : is_constructible_array_type_impl<BasicJsonType, ConstructibleArrayType> {};
+
+template <typename RealIntegerType, typename CompatibleNumberIntegerType,
+          typename = void>
+struct is_compatible_integer_type_impl : std::false_type {};
+
+template <typename RealIntegerType, typename CompatibleNumberIntegerType>
+struct is_compatible_integer_type_impl <
+    RealIntegerType, CompatibleNumberIntegerType,
+    enable_if_t < std::is_integral<RealIntegerType>::value&&
+    std::is_integral<CompatibleNumberIntegerType>::value&&
+    !std::is_same<bool, CompatibleNumberIntegerType>::value >>
+{
+    // is there an assert somewhere on overflows?
+    using RealLimits = std::numeric_limits<RealIntegerType>;
+    using CompatibleLimits = std::numeric_limits<CompatibleNumberIntegerType>;
+
+    static constexpr auto value =
+        is_constructible<RealIntegerType,
+        CompatibleNumberIntegerType>::value &&
+        CompatibleLimits::is_integer &&
+        RealLimits::is_signed == CompatibleLimits::is_signed;
+};
+
+template <typename RealIntegerType, typename CompatibleNumberIntegerType>
+struct is_compatible_integer_type
+    : is_compatible_integer_type_impl<RealIntegerType,
+      CompatibleNumberIntegerType> {};
+
+template <typename BasicJsonType, typename CompatibleType, typename = void>
+struct is_compatible_type_impl: std::false_type {};
+
+template <typename BasicJsonType, typename CompatibleType>
+struct is_compatible_type_impl <
+    BasicJsonType, CompatibleType,
+    enable_if_t<is_complete_type<CompatibleType>::value >>
+{
+    static constexpr bool value =
+        has_to_json<BasicJsonType, CompatibleType>::value;
+};
+
+template <typename BasicJsonType, typename CompatibleType>
+struct is_compatible_type
+    : is_compatible_type_impl<BasicJsonType, CompatibleType> {};
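+
+// Illustrative note (added for review, not part of the upstream nlohmann/json
+// header): is_compatible_type is the trait that constrains basic_json's converting
+// constructor, which is why an assignment such as
+//     nlohmann::json j = std::map<std::string, int>{{"answer", 42}};
+// compiles only when has_to_json can find a usable to_json overload for the
+// right-hand type.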
+
+template<typename T1, typename T2>
+struct is_constructible_tuple : std::false_type {};
+
+template<typename T1, typename... Args>
+struct is_constructible_tuple<T1, std::tuple<Args...>> : conjunction<is_constructible<T1, Args>...> {};
+
+template<typename BasicJsonType, typename T>
+struct is_json_iterator_of : std::false_type {};
+
+template<typename BasicJsonType>
+struct is_json_iterator_of<BasicJsonType, typename BasicJsonType::iterator> : std::true_type {};
+
+template<typename BasicJsonType>
+struct is_json_iterator_of<BasicJsonType, typename BasicJsonType::const_iterator> : std::true_type
+{};
+
+// checks if a given type T is a template specialization of Primary
+template<template <typename...> class Primary, typename T>