Compare commits
1 Commits
pdevine/lo
...
bmizerany/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c1204b686 |
@@ -106,11 +106,9 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||||
|
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
|
|
||||||
|
|
||||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||||
install(TARGETS ggml-hip
|
install(TARGETS ggml-hip
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
|
|||||||
RUN yum install -y yum-utils \
|
RUN yum install -y yum-utils \
|
||||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||||
|
|
||||||
@@ -86,11 +86,10 @@ RUN --mount=type=cache,target=/root/.ccache \
|
|||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
FROM base AS build
|
FROM base AS build
|
||||||
WORKDIR /go/src/github.com/ollama/ollama
|
ARG GOVERSION=1.23.4
|
||||||
COPY go.mod go.sum .
|
RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
||||||
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
|
|
||||||
ENV PATH=/usr/local/go/bin:$PATH
|
ENV PATH=/usr/local/go/bin:$PATH
|
||||||
RUN go mod download
|
WORKDIR /go/src/github.com/ollama/ollama
|
||||||
COPY . .
|
COPY . .
|
||||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||||
ENV CGO_ENABLED=1
|
ENV CGO_ENABLED=1
|
||||||
|
|||||||
23
README.md
23
README.md
@@ -1,5 +1,5 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
<a href="https://ollama.com">
|
<a href="https://ollama.com" />
|
||||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
@@ -54,11 +54,6 @@ Here are some example models that can be downloaded:
|
|||||||
|
|
||||||
| Model | Parameters | Size | Download |
|
| Model | Parameters | Size | Download |
|
||||||
| ------------------ | ---------- | ----- | -------------------------------- |
|
| ------------------ | ---------- | ----- | -------------------------------- |
|
||||||
| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
|
|
||||||
| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
|
|
||||||
| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
|
|
||||||
| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
|
|
||||||
| QwQ | 32B | 20GB | `ollama run qwq` |
|
|
||||||
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
|
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
|
||||||
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
|
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
|
||||||
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
|
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
|
||||||
@@ -69,7 +64,10 @@ Here are some example models that can be downloaded:
|
|||||||
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
|
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
|
||||||
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
|
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
|
||||||
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
|
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
|
||||||
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
|
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
|
||||||
|
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
|
||||||
|
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
|
||||||
|
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
|
||||||
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
| Mistral | 7B | 4.1GB | `ollama run mistral` |
|
||||||
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
|
| Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
|
||||||
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
|
||||||
@@ -77,7 +75,7 @@ Here are some example models that can be downloaded:
|
|||||||
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
|
||||||
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
|
||||||
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
| LLaVA | 7B | 4.5GB | `ollama run llava` |
|
||||||
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
|
| Solar | 10.7B | 6.1GB | `ollama run solar` |
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
|
||||||
@@ -277,7 +275,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
### Web & Desktop
|
### Web & Desktop
|
||||||
|
|
||||||
- [Open WebUI](https://github.com/open-webui/open-webui)
|
- [Open WebUI](https://github.com/open-webui/open-webui)
|
||||||
- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
|
|
||||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||||
- [Hollama](https://github.com/fmaclen/hollama)
|
- [Hollama](https://github.com/fmaclen/hollama)
|
||||||
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
|
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
|
||||||
@@ -390,8 +387,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
|
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
|
||||||
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
|
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
|
||||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
||||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
|
||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
@@ -435,7 +430,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
|
|
||||||
### Apple Vision Pro
|
### Apple Vision Pro
|
||||||
|
|
||||||
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
|
|
||||||
- [Enchanted](https://github.com/AugustDev/enchanted)
|
- [Enchanted](https://github.com/AugustDev/enchanted)
|
||||||
|
|
||||||
### Database
|
### Database
|
||||||
@@ -513,13 +507,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
|
|
||||||
### Mobile
|
### Mobile
|
||||||
|
|
||||||
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
|
|
||||||
- [Enchanted](https://github.com/AugustDev/enchanted)
|
- [Enchanted](https://github.com/AugustDev/enchanted)
|
||||||
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
|
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
|
||||||
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
|
||||||
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
|
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
|
||||||
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
|
|
||||||
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
|
|
||||||
|
|
||||||
### Extensions & Plugins
|
### Extensions & Plugins
|
||||||
|
|
||||||
@@ -565,14 +556,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
|
||||||
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
|
||||||
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
|
||||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
|
||||||
|
|
||||||
### Supported backends
|
### Supported backends
|
||||||
|
|
||||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||||
|
|
||||||
### Observability
|
### Observability
|
||||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
|
|
||||||
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
|
||||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||||
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
||||||
|
|||||||
14
api/types.go
14
api/types.go
@@ -349,7 +349,6 @@ type ShowResponse struct {
|
|||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||||
Tensors []Tensor `json:"tensors,omitempty"`
|
|
||||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,9 +361,9 @@ type CopyRequest struct {
|
|||||||
// PullRequest is the request passed to [Client.Pull].
|
// PullRequest is the request passed to [Client.Pull].
|
||||||
type PullRequest struct {
|
type PullRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
|
Insecure bool `json:"insecure,omitempty"`
|
||||||
Username string `json:"username"` // Deprecated: ignored
|
Username string `json:"username"`
|
||||||
Password string `json:"password"` // Deprecated: ignored
|
Password string `json:"password"`
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
|
|
||||||
// Deprecated: set the model name with Model instead
|
// Deprecated: set the model name with Model instead
|
||||||
@@ -468,13 +467,6 @@ type ModelDetails struct {
|
|||||||
QuantizationLevel string `json:"quantization_level"`
|
QuantizationLevel string `json:"quantization_level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tensor describes the metadata for a given tensor.
|
|
||||||
type Tensor struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Shape []uint64 `json:"shape"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Metrics) Summary() {
|
func (m *Metrics) Summary() {
|
||||||
if m.TotalDuration > 0 {
|
if m.TotalDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
|
|||||||
67
cmd/cmd.go
67
cmd/cmd.go
@@ -18,7 +18,6 @@ import (
|
|||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -35,6 +34,7 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
|
"github.com/ollama/ollama/llama"
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
"github.com/ollama/ollama/progress"
|
"github.com/ollama/ollama/progress"
|
||||||
"github.com/ollama/ollama/runner"
|
"github.com/ollama/ollama/runner"
|
||||||
@@ -256,7 +256,6 @@ func StopHandler(cmd *cobra.Command, args []string) error {
|
|||||||
if strings.Contains(err.Error(), "not found") {
|
if strings.Contains(err.Error(), "not found") {
|
||||||
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
|
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -339,16 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(info.ProjectorInfo) != 0 {
|
// TODO(jessegross): We should either find another way to know if this is
|
||||||
opts.MultiModal = true
|
// a vision model or remove the logic. Also consider that other modalities will
|
||||||
}
|
// need different behavior anyways.
|
||||||
for k := range info.ModelInfo {
|
opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
|
||||||
if strings.Contains(k, ".vision.") {
|
|
||||||
opts.MultiModal = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
opts.ParentModel = info.Details.ParentModel
|
opts.ParentModel = info.Details.ParentModel
|
||||||
|
|
||||||
if interactive {
|
if interactive {
|
||||||
@@ -569,9 +562,8 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
parameters, errParams := cmd.Flags().GetBool("parameters")
|
parameters, errParams := cmd.Flags().GetBool("parameters")
|
||||||
system, errSystem := cmd.Flags().GetBool("system")
|
system, errSystem := cmd.Flags().GetBool("system")
|
||||||
template, errTemplate := cmd.Flags().GetBool("template")
|
template, errTemplate := cmd.Flags().GetBool("template")
|
||||||
verbose, errVerbose := cmd.Flags().GetBool("verbose")
|
|
||||||
|
|
||||||
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
|
for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
|
||||||
if boolErr != nil {
|
if boolErr != nil {
|
||||||
return errors.New("error retrieving flags")
|
return errors.New("error retrieving flags")
|
||||||
}
|
}
|
||||||
@@ -609,7 +601,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
|
return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := api.ShowRequest{Name: args[0], Verbose: verbose}
|
req := api.ShowRequest{Name: args[0]}
|
||||||
resp, err := client.Show(cmd.Context(), &req)
|
resp, err := client.Show(cmd.Context(), &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -632,10 +624,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return showInfo(resp, verbose, os.Stdout)
|
return showInfo(resp, os.Stdout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
func showInfo(resp *api.ShowResponse, w io.Writer) error {
|
||||||
tableRender := func(header string, rows func() [][]string) {
|
tableRender := func(header string, rows func() [][]string) {
|
||||||
fmt.Fprintln(w, " ", header)
|
fmt.Fprintln(w, " ", header)
|
||||||
table := tablewriter.NewWriter(w)
|
table := tablewriter.NewWriter(w)
|
||||||
@@ -692,45 +684,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.ModelInfo != nil && verbose {
|
|
||||||
tableRender("Metadata", func() (rows [][]string) {
|
|
||||||
keys := make([]string, 0, len(resp.ModelInfo))
|
|
||||||
for k := range resp.ModelInfo {
|
|
||||||
keys = append(keys, k)
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
for _, k := range keys {
|
|
||||||
var v string
|
|
||||||
switch vData := resp.ModelInfo[k].(type) {
|
|
||||||
case string:
|
|
||||||
v = vData
|
|
||||||
case float64:
|
|
||||||
v = fmt.Sprintf("%g", vData)
|
|
||||||
case []any:
|
|
||||||
n := 3
|
|
||||||
if len(vData) < n {
|
|
||||||
n = len(vData)
|
|
||||||
}
|
|
||||||
v = fmt.Sprintf("%v", vData[:n])
|
|
||||||
default:
|
|
||||||
v = fmt.Sprintf("%T", vData)
|
|
||||||
}
|
|
||||||
rows = append(rows, []string{"", k, v})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Tensors) > 0 && verbose {
|
|
||||||
tableRender("Tensors", func() (rows [][]string) {
|
|
||||||
for _, t := range resp.Tensors {
|
|
||||||
rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
head := func(s string, n int) (rows [][]string) {
|
head := func(s string, n int) (rows [][]string) {
|
||||||
scanner := bufio.NewScanner(strings.NewReader(s))
|
scanner := bufio.NewScanner(strings.NewReader(s))
|
||||||
for scanner.Scan() && (len(rows) < n || n < 0) {
|
for scanner.Scan() && (len(rows) < n || n < 0) {
|
||||||
@@ -1237,7 +1190,6 @@ func NewCLI() *cobra.Command {
|
|||||||
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
|
showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
|
||||||
showCmd.Flags().Bool("template", false, "Show template of a model")
|
showCmd.Flags().Bool("template", false, "Show template of a model")
|
||||||
showCmd.Flags().Bool("system", false, "Show system message of a model")
|
showCmd.Flags().Bool("system", false, "Show system message of a model")
|
||||||
showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
|
|
||||||
|
|
||||||
runCmd := &cobra.Command{
|
runCmd := &cobra.Command{
|
||||||
Use: "run MODEL [PROMPT]",
|
Use: "run MODEL [PROMPT]",
|
||||||
@@ -1322,6 +1274,7 @@ func NewCLI() *cobra.Command {
|
|||||||
|
|
||||||
runnerCmd := &cobra.Command{
|
runnerCmd := &cobra.Command{
|
||||||
Use: "runner",
|
Use: "runner",
|
||||||
|
Short: llama.PrintSystemInfo(),
|
||||||
Hidden: true,
|
Hidden: true,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
return runner.Execute(os.Args[1:])
|
return runner.Execute(os.Args[1:])
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
ParameterSize: "7B",
|
ParameterSize: "7B",
|
||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
}, false, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
ParameterSize: "7B",
|
ParameterSize: "7B",
|
||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
}, false, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,56 +68,6 @@ func TestShowInfo(t *testing.T) {
|
|||||||
embedding length 0
|
embedding length 0
|
||||||
quantization FP16
|
quantization FP16
|
||||||
|
|
||||||
`
|
|
||||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
|
||||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("verbose model", func(t *testing.T) {
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := showInfo(&api.ShowResponse{
|
|
||||||
Details: api.ModelDetails{
|
|
||||||
Family: "test",
|
|
||||||
ParameterSize: "8B",
|
|
||||||
QuantizationLevel: "FP16",
|
|
||||||
},
|
|
||||||
Parameters: `
|
|
||||||
stop up`,
|
|
||||||
ModelInfo: map[string]any{
|
|
||||||
"general.architecture": "test",
|
|
||||||
"general.parameter_count": float64(8_000_000_000),
|
|
||||||
"test.context_length": float64(1000),
|
|
||||||
"test.embedding_length": float64(11434),
|
|
||||||
},
|
|
||||||
Tensors: []api.Tensor{
|
|
||||||
{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
|
|
||||||
{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
|
|
||||||
},
|
|
||||||
}, true, &b); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expect := ` Model
|
|
||||||
architecture test
|
|
||||||
parameters 8B
|
|
||||||
context length 1000
|
|
||||||
embedding length 11434
|
|
||||||
quantization FP16
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
stop up
|
|
||||||
|
|
||||||
Metadata
|
|
||||||
general.architecture test
|
|
||||||
general.parameter_count 8e+09
|
|
||||||
test.context_length 1000
|
|
||||||
test.embedding_length 11434
|
|
||||||
|
|
||||||
Tensors
|
|
||||||
blk.0.attn_k.weight BF16 [42 3117]
|
|
||||||
blk.0.attn_q.weight FP16 [3117 42]
|
|
||||||
|
|
||||||
`
|
`
|
||||||
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
if diff := cmp.Diff(expect, b.String()); diff != "" {
|
||||||
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
t.Errorf("unexpected output (-want +got):\n%s", diff)
|
||||||
@@ -139,7 +89,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
stop you
|
stop you
|
||||||
stop up
|
stop up
|
||||||
temperature 99`,
|
temperature 99`,
|
||||||
}, false, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +126,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
"clip.vision.embedding_length": float64(0),
|
"clip.vision.embedding_length": float64(0),
|
||||||
"clip.vision.projection_dim": float64(0),
|
"clip.vision.projection_dim": float64(0),
|
||||||
},
|
},
|
||||||
}, false, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,7 +159,7 @@ func TestShowInfo(t *testing.T) {
|
|||||||
Ahoy, matey!
|
Ahoy, matey!
|
||||||
Weigh anchor!
|
Weigh anchor!
|
||||||
`,
|
`,
|
||||||
}, false, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +188,7 @@ Weigh anchor!
|
|||||||
QuantizationLevel: "FP16",
|
QuantizationLevel: "FP16",
|
||||||
},
|
},
|
||||||
License: license,
|
License: license,
|
||||||
}, false, &b); err != nil {
|
}, &b); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -195,10 +195,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
opts.Messages = []api.Message{}
|
opts.Messages = []api.Message{}
|
||||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
if strings.Contains(err.Error(), "not found") {
|
|
||||||
fmt.Printf("error: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -347,7 +343,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||||||
|
|
||||||
switch args[1] {
|
switch args[1] {
|
||||||
case "info":
|
case "info":
|
||||||
_ = showInfo(resp, false, os.Stderr)
|
_ = showInfo(resp, os.Stderr)
|
||||||
case "license":
|
case "license":
|
||||||
if resp.License == "" {
|
if resp.License == "" {
|
||||||
fmt.Println("No license was specified for this model.")
|
fmt.Println("No license was specified for this model.")
|
||||||
|
|||||||
@@ -13,13 +13,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ModelParameters struct {
|
type ModelParameters struct {
|
||||||
Architectures []string `json:"architectures"`
|
Architectures []string `json:"architectures"`
|
||||||
VocabSize uint32 `json:"vocab_size"`
|
VocabSize uint32 `json:"vocab_size"`
|
||||||
TextModel TextParameters `json:"text_config"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextParameters struct {
|
|
||||||
VocabSize uint32 `json:"vocab_size"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type AdapterParameters struct {
|
type AdapterParameters struct {
|
||||||
@@ -190,8 +185,6 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
conv = &gemmaModel{}
|
conv = &gemmaModel{}
|
||||||
case "Gemma2ForCausalLM":
|
case "Gemma2ForCausalLM":
|
||||||
conv = &gemma2Model{}
|
conv = &gemma2Model{}
|
||||||
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
|
||||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
|
||||||
case "Phi3ForCausalLM":
|
case "Phi3ForCausalLM":
|
||||||
conv = &phi3Model{}
|
conv = &phi3Model{}
|
||||||
case "Qwen2ForCausalLM":
|
case "Qwen2ForCausalLM":
|
||||||
@@ -220,14 +213,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
vocabSize := int(p.VocabSize)
|
vocabSize := int(p.VocabSize)
|
||||||
if vocabSize == 0 {
|
|
||||||
tVocabSize := int(p.TextModel.VocabSize)
|
|
||||||
vocabSize = tVocabSize
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case vocabSize == 0:
|
|
||||||
slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
|
||||||
case vocabSize > len(t.Vocabulary.Tokens):
|
case vocabSize > len(t.Vocabulary.Tokens):
|
||||||
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
|
|||||||
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||||
var out []ggml.Tensor
|
var out []ggml.Tensor
|
||||||
for _, t := range ts {
|
for _, t := range ts {
|
||||||
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
|
if strings.HasSuffix(t.Name(), "_norm.weight") {
|
||||||
t.SetRepacker(p.addOne)
|
t.SetRepacker(p.addOne)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,142 +0,0 @@
|
|||||||
package convert
|
|
||||||
|
|
||||||
import (
|
|
||||||
"cmp"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
|
||||||
)
|
|
||||||
|
|
||||||
type gemma3Model struct {
|
|
||||||
gemmaModel
|
|
||||||
Architecture string
|
|
||||||
TextModel struct {
|
|
||||||
HeadDim uint32 `json:"head_dim"`
|
|
||||||
HiddenSize uint32 `json:"hidden_size"`
|
|
||||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
|
||||||
IntermediateSize uint32 `json:"intermediate_size"`
|
|
||||||
SlidingWindow uint32 `json:"sliding_window"`
|
|
||||||
} `json:"text_config"`
|
|
||||||
VisionModel struct {
|
|
||||||
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
|
|
||||||
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
|
|
||||||
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
|
|
||||||
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
|
|
||||||
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
|
|
||||||
ImageSize uint32 `json:"image_size"` // image_size 560
|
|
||||||
NumChannels uint32 `json:"num_channels"` // num_channels 3
|
|
||||||
PatchSize uint32 `json:"patch_size"` // patch_size 14
|
|
||||||
} `json:"vision_config"`
|
|
||||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
|
||||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
|
||||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
|
||||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
|
||||||
HeadDim uint32 `json:"head_dim"`
|
|
||||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
|
||||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
|
||||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
|
||||||
SlidingWindow uint32 `json:"sliding_window"`
|
|
||||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
gemma4BLayerCount = 34
|
|
||||||
gemma12BLayerCount = 48
|
|
||||||
gemma27BLayerCount = 62
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
|
||||||
kv := p.ModelParameters.KV(t)
|
|
||||||
kv["general.architecture"] = "gemma3"
|
|
||||||
|
|
||||||
numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
|
|
||||||
kv["gemma3.block_count"] = numBlocks
|
|
||||||
|
|
||||||
var (
|
|
||||||
numHeads uint32
|
|
||||||
numKVHeads uint32
|
|
||||||
)
|
|
||||||
|
|
||||||
switch numBlocks {
|
|
||||||
case gemma4BLayerCount:
|
|
||||||
numHeads = 8
|
|
||||||
numKVHeads = 4
|
|
||||||
case gemma12BLayerCount:
|
|
||||||
numHeads = 16
|
|
||||||
numKVHeads = 8
|
|
||||||
case gemma27BLayerCount:
|
|
||||||
numHeads = 32
|
|
||||||
numKVHeads = 16
|
|
||||||
default:
|
|
||||||
numHeads = p.NumAttentionHeads
|
|
||||||
numKVHeads = p.NumKeyValueHeads
|
|
||||||
}
|
|
||||||
|
|
||||||
kv["gemma3.attention.head_count"] = numHeads
|
|
||||||
kv["gemma3.attention.head_count_kv"] = numKVHeads
|
|
||||||
|
|
||||||
switch p.Architecture {
|
|
||||||
case "Gemma3ForCausalLM":
|
|
||||||
kv["gemma3.context_length"] = p.MaxPositionEmbeddings
|
|
||||||
kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
|
||||||
kv["gemma3.attention.key_length"] = p.HeadDim
|
|
||||||
kv["gemma3.attention.value_length"] = p.HeadDim
|
|
||||||
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
|
|
||||||
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
|
|
||||||
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
|
|
||||||
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
|
|
||||||
kv["gemma3.embedding_length"] = p.HiddenSize
|
|
||||||
kv["gemma3.feed_forward_length"] = p.IntermediateSize
|
|
||||||
default:
|
|
||||||
kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
|
|
||||||
kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
|
|
||||||
kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
|
|
||||||
kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
|
|
||||||
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
|
||||||
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
|
|
||||||
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
|
|
||||||
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
|
|
||||||
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
|
|
||||||
kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
|
|
||||||
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
|
|
||||||
kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
|
|
||||||
kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
|
|
||||||
kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.MultiModalTokensPerImage > 0 {
|
|
||||||
kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
|
|
||||||
}
|
|
||||||
|
|
||||||
return kv
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *gemma3Model) Replacements() []string {
|
|
||||||
return []string{
|
|
||||||
"lm_head", "output",
|
|
||||||
"model.embed_tokens", "token_embd",
|
|
||||||
"model.norm", "output_norm",
|
|
||||||
"vision_tower.vision_model.embeddings", "v",
|
|
||||||
"vision_tower.vision_model", "v",
|
|
||||||
"vision_model.vision_model.embeddings", "v",
|
|
||||||
"vision_model.vision_model", "v",
|
|
||||||
"language_model.", "",
|
|
||||||
"model.layers", "blk",
|
|
||||||
"encoder.layers", "blk",
|
|
||||||
"input_layernorm", "attn_norm",
|
|
||||||
"self_attn.q_proj", "attn_q",
|
|
||||||
"self_attn.q_norm", "attn_q_norm",
|
|
||||||
"self_attn.k_proj", "attn_k",
|
|
||||||
"self_attn.k_norm", "attn_k_norm",
|
|
||||||
"self_attn.v_proj", "attn_v",
|
|
||||||
"self_attn.o_proj", "attn_output",
|
|
||||||
"self_attn.out_proj", "attn_output",
|
|
||||||
"mlp.gate_proj", "ffn_gate",
|
|
||||||
"mlp.down_proj", "ffn_down",
|
|
||||||
"mlp.up_proj", "ffn_up",
|
|
||||||
"post_attention_layernorm", "post_attention_norm",
|
|
||||||
"pre_feedforward_layernorm", "ffn_norm",
|
|
||||||
"post_feedforward_layernorm", "post_ffw_norm",
|
|
||||||
"input_projection_weight", "input_projection.weight",
|
|
||||||
"multi_modal_projector", "mm",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -6,9 +6,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
@@ -17,8 +15,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
||||||
slog.Debug("using spm vocabulary")
|
|
||||||
|
|
||||||
ast, err := parseAdditionalSpecialTokens(fsys)
|
ast, err := parseAdditionalSpecialTokens(fsys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -47,19 +43,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|||||||
v.Types = append(v.Types, int32(t))
|
v.Types = append(v.Types, int32(t))
|
||||||
default:
|
default:
|
||||||
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||||||
|
if slices.Contains(ast, piece.GetPiece()) {
|
||||||
// temporary fix to handle gemma3 broken configs
|
|
||||||
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
|
|
||||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range ast {
|
|
||||||
if t.Content == piece.GetPiece() {
|
|
||||||
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
v.Types = append(v.Types, tt)
|
v.Types = append(v.Types, tt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -91,16 +78,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|||||||
return cmp.Compare(i.id, j.id)
|
return cmp.Compare(i.id, j.id)
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, t := range ts {
|
n := len(v.Tokens)
|
||||||
if t.id < len(v.Tokens) {
|
for i, t := range ts {
|
||||||
if v.Tokens[t.id] == t.content {
|
if t.id != i+n {
|
||||||
slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
|
return nil, fmt.Errorf("invalid token id: %d", t.id)
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
|
|
||||||
}
|
|
||||||
if t.id != len(v.Tokens) {
|
|
||||||
return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
v.Tokens = append(v.Tokens, t.content)
|
v.Tokens = append(v.Tokens, t.content)
|
||||||
@@ -111,15 +92,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
|
|||||||
return &v, nil
|
return &v, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type specialToken struct {
|
func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
|
||||||
Content string `json:"content"`
|
|
||||||
Lstrip bool `json:"lstrip"`
|
|
||||||
Normalized bool `json:"normalized"`
|
|
||||||
Rstrip bool `json:"rstrip"`
|
|
||||||
SingleWord bool `json:"single_word"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
|
|
||||||
f, err := fsys.Open("special_tokens_map.json")
|
f, err := fsys.Open("special_tokens_map.json")
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -129,43 +102,12 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
|
|||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
var m struct {
|
var m struct {
|
||||||
AdditionalSpecialTokens any `json:"additional_special_tokens"`
|
AdditionalSpecialTokens []string `json:"additional_special_tokens"`
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
if err := json.NewDecoder(f).Decode(&m); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var ast []specialToken
|
return m.AdditionalSpecialTokens, nil
|
||||||
|
|
||||||
switch st := m.AdditionalSpecialTokens.(type) {
|
|
||||||
case []string:
|
|
||||||
for _, s := range st {
|
|
||||||
ast = append(ast, specialToken{Content: s})
|
|
||||||
}
|
|
||||||
case []any:
|
|
||||||
for _, s := range st {
|
|
||||||
// marshal and unmarshal the object to get the special token
|
|
||||||
tMap := s.(map[string]any)
|
|
||||||
data, err := json.Marshal(tMap)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var token specialToken
|
|
||||||
err = json.Unmarshal(data, &token)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ast = append(ast, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
slog.Warn("special token", "unknown token", reflect.TypeOf(st))
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("spm tokenizer", "additional tokens", ast)
|
|
||||||
|
|
||||||
return ast, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -118,35 +118,6 @@ To run tests, use `go test`:
|
|||||||
go test ./...
|
go test ./...
|
||||||
```
|
```
|
||||||
|
|
||||||
> NOTE: In rare cirumstances, you may nedd to change a package using the new
|
|
||||||
> "synctest" package in go1.24.
|
|
||||||
>
|
|
||||||
> If you do not have the "synctest" package enabled, you will not see build or
|
|
||||||
> test failures resulting from your change(s), if any, locally, but CI will
|
|
||||||
> break.
|
|
||||||
>
|
|
||||||
> If you see failures in CI, you can either keep pushing changes to see if the
|
|
||||||
> CI build passes, or you can enable the "synctest" package locally to see the
|
|
||||||
> failures before pushing.
|
|
||||||
>
|
|
||||||
> To enable the "synctest" package for testing, run the following command:
|
|
||||||
>
|
|
||||||
> ```shell
|
|
||||||
> GOEXPERIMENT=synctest go test ./...
|
|
||||||
> ```
|
|
||||||
>
|
|
||||||
> If you wish to enable synctest for all go commands, you can set the
|
|
||||||
> `GOEXPERIMENT` environment variable in your shell profile or by using:
|
|
||||||
>
|
|
||||||
> ```shell
|
|
||||||
> go env -w GOEXPERIMENT=synctest
|
|
||||||
> ```
|
|
||||||
>
|
|
||||||
> Which will enable the "synctest" package for all go commands without needing
|
|
||||||
> to set it for all shell sessions.
|
|
||||||
>
|
|
||||||
> The synctest package is not required for production builds.
|
|
||||||
|
|
||||||
## Library detection
|
## Library detection
|
||||||
|
|
||||||
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
|
Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
|
|||||||
|
|
||||||
## How can I specify the context window size?
|
## How can I specify the context window size?
|
||||||
|
|
||||||
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
|
By default, Ollama uses a context window size of 2048 tokens.
|
||||||
|
|
||||||
To change this when using `ollama run`, use `/set parameter`:
|
To change this when using `ollama run`, use `/set parameter`:
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ RestartSec=3
|
|||||||
Environment="PATH=$PATH"
|
Environment="PATH=$PATH"
|
||||||
|
|
||||||
[Install]
|
[Install]
|
||||||
WantedBy=multi-user.target
|
WantedBy=default.target
|
||||||
```
|
```
|
||||||
|
|
||||||
Then start the service:
|
Then start the service:
|
||||||
|
|||||||
@@ -81,11 +81,9 @@ help you keep up to date.
|
|||||||
|
|
||||||
If you'd like to install or integrate Ollama as a service, a standalone
|
If you'd like to install or integrate Ollama as a service, a standalone
|
||||||
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
`ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
|
||||||
and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
|
and GPU library dependencies for Nvidia and AMD. This allows for embedding
|
||||||
and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
|
Ollama in existing applications, or running it as a system service via `ollama
|
||||||
same directory. This allows for embedding Ollama in existing applications, or
|
serve` with tools such as [NSSM](https://nssm.cc/).
|
||||||
running it as a system service via `ollama serve` with tools such as
|
|
||||||
[NSSM](https://nssm.cc/).
|
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> If you are upgrading from a prior version, you should remove the old directories first.
|
> If you are upgrading from a prior version, you should remove the old directories first.
|
||||||
|
|||||||
@@ -124,19 +124,6 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
|
||||||
r := keyValue(kv, key, &array{})
|
|
||||||
s := make([]float32, r.size)
|
|
||||||
for i := range r.size {
|
|
||||||
s[i] = float32(r.values[i].(float32))
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (kv KV) OllamaEngineRequired() bool {
|
|
||||||
return kv.Architecture() == "gemma3"
|
|
||||||
}
|
|
||||||
|
|
||||||
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
||||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||||
key = kv.Architecture() + "." + key
|
key = kv.Architecture() + "." + key
|
||||||
@@ -327,10 +314,6 @@ func (t Tensor) Size() uint64 {
|
|||||||
return t.parameters() * t.typeSize() / t.blockSize()
|
return t.parameters() * t.typeSize() / t.blockSize()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tensor) Type() string {
|
|
||||||
return fileType(t.Kind).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
type container interface {
|
type container interface {
|
||||||
Name() string
|
Name() string
|
||||||
Decode(io.ReadSeeker) (model, error)
|
Decode(io.ReadSeeker) (model, error)
|
||||||
@@ -493,7 +476,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
// vocab graph
|
// vocab graph
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
case "gemma", "gemma2", "gemma3":
|
case "gemma", "gemma2":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||||
@@ -582,43 +565,6 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
|
||||||
switch llm.KV().Architecture() {
|
|
||||||
case "mllama":
|
|
||||||
for _, layer := range llm.Tensors().GroupLayers()["v"] {
|
|
||||||
weights += layer.Size()
|
|
||||||
}
|
|
||||||
|
|
||||||
kv := func(n string) uint64 {
|
|
||||||
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
|
|
||||||
return uint64(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
imageSize := kv("image_size")
|
|
||||||
|
|
||||||
maxNumTiles := kv("max_num_tiles")
|
|
||||||
embeddingLength := kv("embedding_length")
|
|
||||||
headCount := kv("attention.head_count")
|
|
||||||
|
|
||||||
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
|
||||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
|
||||||
numPatches++
|
|
||||||
}
|
|
||||||
|
|
||||||
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
|
||||||
|
|
||||||
graphSize = 4 * (8 +
|
|
||||||
imageSize*imageSize*kv("num_channels")*maxNumTiles +
|
|
||||||
embeddingLength*numPatches*maxNumTiles +
|
|
||||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
|
||||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
|
||||||
}
|
|
||||||
return weights, graphSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// SupportsKVCacheType checks if the requested cache type is supported
|
// SupportsKVCacheType checks if the requested cache type is supported
|
||||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -24,7 +24,7 @@ require (
|
|||||||
github.com/nlpodyssey/gopickle v0.3.0
|
github.com/nlpodyssey/gopickle v0.3.0
|
||||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||||
golang.org/x/image v0.22.0
|
golang.org/x/image v0.22.0
|
||||||
golang.org/x/tools v0.30.0
|
gonum.org/v1/gonum v0.15.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
@@ -44,7 +44,6 @@ require (
|
|||||||
github.com/xtgo/set v1.0.0 // indirect
|
github.com/xtgo/set v1.0.0 // indirect
|
||||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||||
gonum.org/v1/gonum v0.15.0 // indirect
|
|
||||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||||
)
|
)
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -309,8 +309,6 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
|
|||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||||
golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
|
|
||||||
golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
|
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -30,17 +29,6 @@ type Cache interface {
|
|||||||
// cache implementation used.
|
// cache implementation used.
|
||||||
Put(ctx ml.Context, key, value ml.Tensor)
|
Put(ctx ml.Context, key, value ml.Tensor)
|
||||||
|
|
||||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
|
||||||
// the output of the cache to work better with specific kernels. If not called,
|
|
||||||
// the backend settings will be used. This works well when calling Attention.
|
|
||||||
//
|
|
||||||
// The config can be overridden by models, especially if they require vanilla
|
|
||||||
// output when implementing their own version of attention. To do this, pass
|
|
||||||
// an empty ml.CacheConfig.
|
|
||||||
//
|
|
||||||
// Most models will not need to use this.
|
|
||||||
SetConfig(ml.CacheConfig)
|
|
||||||
|
|
||||||
// ** cache management **
|
// ** cache management **
|
||||||
|
|
||||||
// Init sets up runtime parameters
|
// Init sets up runtime parameters
|
||||||
@@ -52,7 +40,7 @@ type Cache interface {
|
|||||||
// StartForward is called before the start of the model's forward pass.
|
// StartForward is called before the start of the model's forward pass.
|
||||||
// For each token in the coming batch, there must be a corresponding
|
// For each token in the coming batch, there must be a corresponding
|
||||||
// entry in positions and seqs.
|
// entry in positions and seqs.
|
||||||
StartForward(ctx ml.Context, opts input.Options) error
|
StartForward(ctx ml.Context, positions []int32, seqs []int) error
|
||||||
|
|
||||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||||
@@ -23,11 +22,6 @@ type Causal struct {
|
|||||||
Capacity int32
|
Capacity int32
|
||||||
windowSize int32
|
windowSize int32
|
||||||
|
|
||||||
opts CausalOptions
|
|
||||||
|
|
||||||
// config controls mostly backend-specific optimizations
|
|
||||||
config *ml.CacheConfig
|
|
||||||
|
|
||||||
// ** current forward pass **
|
// ** current forward pass **
|
||||||
|
|
||||||
// the active layer for Get and Put
|
// the active layer for Get and Put
|
||||||
@@ -45,12 +39,6 @@ type Causal struct {
|
|||||||
// locations in the cache that are needed for this batch
|
// locations in the cache that are needed for this batch
|
||||||
curCellRange cellRange
|
curCellRange cellRange
|
||||||
|
|
||||||
// curSequences is the sequences corresponding to this pass's entries in the cache
|
|
||||||
curSequences []int
|
|
||||||
|
|
||||||
// curPositions is the positions corresponding to this pass's entries in the cache
|
|
||||||
curPositions []int32
|
|
||||||
|
|
||||||
// ** cache metadata **
|
// ** cache metadata **
|
||||||
|
|
||||||
// for each possible location in the cache, stores the position and set of sequences
|
// for each possible location in the cache, stores the position and set of sequences
|
||||||
@@ -64,8 +52,8 @@ type Causal struct {
|
|||||||
|
|
||||||
shiftFn shiftFn
|
shiftFn shiftFn
|
||||||
backend ml.Backend
|
backend ml.Backend
|
||||||
ctxs map[int]ml.Context
|
cacheCtx ml.Context
|
||||||
keys, values map[int]ml.Tensor
|
keys, values []ml.Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
type cacheCell struct {
|
type cacheCell struct {
|
||||||
@@ -79,72 +67,28 @@ type cellRange struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewCausalCache(shift shiftFn) *Causal {
|
func NewCausalCache(shift shiftFn) *Causal {
|
||||||
return &Causal{
|
return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
|
||||||
windowSize: math.MaxInt32,
|
|
||||||
shiftFn: shift,
|
|
||||||
ctxs: make(map[int]ml.Context),
|
|
||||||
keys: make(map[int]ml.Tensor),
|
|
||||||
values: make(map[int]ml.Tensor),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||||
return &Causal{
|
return &Causal{windowSize: windowSize, shiftFn: shift}
|
||||||
windowSize: windowSize,
|
|
||||||
shiftFn: shift,
|
|
||||||
ctxs: make(map[int]ml.Context),
|
|
||||||
keys: make(map[int]ml.Tensor),
|
|
||||||
values: make(map[int]ml.Tensor),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||||
if c.config == nil {
|
|
||||||
var config ml.CacheConfig
|
|
||||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
|
||||||
config = cc.CacheConfig()
|
|
||||||
}
|
|
||||||
c.config = &config
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.CachePadding == 0 {
|
|
||||||
c.config.CachePadding = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.MaskBatchPadding == 0 {
|
|
||||||
c.config.MaskBatchPadding = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.MaskDType == ml.DTypeOther {
|
|
||||||
c.config.MaskDType = ml.DTypeF32
|
|
||||||
}
|
|
||||||
|
|
||||||
c.DType = dtype
|
c.DType = dtype
|
||||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
c.Capacity = capacity
|
||||||
c.cells = make([]cacheCell, c.Capacity)
|
c.cells = make([]cacheCell, capacity)
|
||||||
c.cellRanges = make(map[int]cellRange)
|
c.cellRanges = make(map[int]cellRange)
|
||||||
c.backend = backend
|
c.backend = backend
|
||||||
}
|
c.cacheCtx = backend.NewContext()
|
||||||
|
|
||||||
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
|
||||||
if c.config != nil {
|
|
||||||
panic("config cannot be changed after being previously set, either by the model or backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.config = &config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Close() {
|
func (c *Causal) Close() {
|
||||||
for _, ctx := range c.ctxs {
|
c.cacheCtx.Close()
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||||
c.curBatchSize = len(opts.Positions)
|
c.curBatchSize = len(positions)
|
||||||
c.curSequences = opts.Sequences
|
|
||||||
c.curPositions = opts.Positions
|
|
||||||
c.opts.Except = nil
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
c.curLoc, err = c.findStartLoc()
|
c.curLoc, err = c.findStartLoc()
|
||||||
@@ -157,8 +101,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.curCellRange = newRange()
|
c.curCellRange = newRange()
|
||||||
for i, pos := range opts.Positions {
|
for i, pos := range positions {
|
||||||
seq := opts.Sequences[i]
|
seq := seqs[i]
|
||||||
|
|
||||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||||
|
|
||||||
@@ -183,7 +127,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
|||||||
c.cellRanges[seq] = seqRange
|
c.cellRanges[seq] = seqRange
|
||||||
}
|
}
|
||||||
|
|
||||||
c.curMask, err = c.buildMask(ctx)
|
c.curMask, err = c.buildMask(ctx, positions, seqs)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -213,91 +157,36 @@ func (c *Causal) findStartLoc() (int, error) {
|
|||||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||||
}
|
}
|
||||||
|
|
||||||
func roundDown(length, pad int) int {
|
|
||||||
return (length / pad) * pad
|
|
||||||
}
|
|
||||||
|
|
||||||
func roundUp(length, pad int) int {
|
|
||||||
return ((length + pad - 1) / pad) * pad
|
|
||||||
}
|
|
||||||
|
|
||||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||||
// token in the history should apply. This is based on both the sequence and causality (the
|
// token in the history should apply. This is based on both the sequence and causality (the
|
||||||
// position of the history is not ahead of the token in the batch).
|
// position of the history is not ahead of the token in the batch).
|
||||||
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
||||||
// Align and pad the two dimensions as required by the backend
|
// TODO(jessegross): This does not do padding, which is required for flash attention
|
||||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
len := c.curCellRange.max - c.curCellRange.min + 1
|
||||||
|
mask := make([]float32, c.curBatchSize*len)
|
||||||
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
|
||||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
|
||||||
|
|
||||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
|
||||||
mask := make([]float32, batchSize*length)
|
|
||||||
|
|
||||||
for i := range c.curBatchSize {
|
for i := range c.curBatchSize {
|
||||||
enabled := !slices.Contains(c.opts.Except, i)
|
|
||||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
||||||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
c.cells[j].pos < positions[i]-c.windowSize {
|
||||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
|
||||||
// has already been masked out because the sequence doesn't match.
|
|
||||||
for i := c.curBatchSize * length; i < len(mask); i++ {
|
|
||||||
mask[i] = float32(math.Inf(-1))
|
|
||||||
}
|
|
||||||
|
|
||||||
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.MaskDType != ml.DTypeF32 {
|
|
||||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
|
||||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
|
||||||
maskTensor = out
|
|
||||||
}
|
|
||||||
|
|
||||||
return maskTensor, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
|
||||||
for i, key := range c.keys {
|
for _, obj := range objs {
|
||||||
if key == nil {
|
if obj == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
kHeadDim := key.Dim(0)
|
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
|
||||||
numKVHeads := key.Dim(1)
|
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
|
||||||
rowSize := key.Stride(2)
|
|
||||||
|
|
||||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
ctx.Forward(srcView.Copy(ctx, dstView))
|
||||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
|
||||||
|
|
||||||
value := c.values[i]
|
|
||||||
var vSrcView, vDstView ml.Tensor
|
|
||||||
if c.config.PermutedV {
|
|
||||||
vHeadDim := value.Dim(1)
|
|
||||||
elemSize := value.Stride(0)
|
|
||||||
|
|
||||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
|
||||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
|
||||||
} else {
|
|
||||||
vHeadDim := value.Dim(0)
|
|
||||||
rowSize := value.Stride(2)
|
|
||||||
|
|
||||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
|
||||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.Forward(
|
|
||||||
kSrcView.Copy(ctx, kDstView),
|
|
||||||
vSrcView.Copy(ctx, vDstView),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +219,7 @@ func (c *Causal) defrag() {
|
|||||||
layers++
|
layers++
|
||||||
}
|
}
|
||||||
|
|
||||||
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
|
maxMoves := ctx.MaxTensors() / (6 * layers)
|
||||||
moves := 0
|
moves := 0
|
||||||
|
|
||||||
var pendingSrc, pendingDst, pendingLen int
|
var pendingSrc, pendingDst, pendingLen int
|
||||||
@@ -349,7 +238,8 @@ func (c *Causal) defrag() {
|
|||||||
pendingLen++
|
pendingLen++
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||||
|
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||||
moves++
|
moves++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -373,7 +263,8 @@ func (c *Causal) defrag() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if pendingLen > 0 {
|
if pendingLen > 0 {
|
||||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||||
|
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||||
moves++
|
moves++
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,107 +293,47 @@ func (c *Causal) defrag() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) SetLayer(layer int) {
|
func (c *Causal) SetLayer(layer int) {
|
||||||
c.curLayer = layer
|
if layer >= len(c.keys) {
|
||||||
}
|
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
|
||||||
|
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
|
||||||
type CausalOptions struct {
|
|
||||||
// Enabled controls whether the causal mask is generated for a particular index in a batch
|
|
||||||
Except []int
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCausal disables causal mask generation for a particular range of indicies in
|
|
||||||
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
|
|
||||||
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
|
||||||
if !slices.Equal(c.opts.Except, opts.Except) {
|
|
||||||
c.opts = opts
|
|
||||||
if ctx != nil {
|
|
||||||
var err error
|
|
||||||
c.curMask, err = c.buildMask(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// This error should never occur because we have previously built a mask with the same shape
|
|
||||||
panic(fmt.Errorf("SetCausal: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.curLayer = layer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||||
key := c.keys[c.curLayer]
|
key := c.keys[c.curLayer]
|
||||||
value := c.values[c.curLayer]
|
value := c.values[c.curLayer]
|
||||||
|
|
||||||
kHeadDim := key.Dim(0)
|
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||||
numKVHeads := key.Dim(1)
|
key.Dim(0), key.Stride(1),
|
||||||
rowSize := key.Stride(2)
|
key.Dim(1), key.Stride(2),
|
||||||
cachedSize := c.curMask.Dim(0)
|
c.curMask.Dim(0),
|
||||||
|
|
||||||
key = key.View(ctx, rowSize*c.curCellRange.min,
|
|
||||||
kHeadDim, key.Stride(1),
|
|
||||||
numKVHeads, key.Stride(2),
|
|
||||||
cachedSize,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if c.config.PermutedV {
|
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||||
vHeadDim := value.Dim(1)
|
value.Dim(0), value.Stride(1),
|
||||||
elemSize := value.Stride(0)
|
value.Dim(1), value.Stride(2),
|
||||||
|
c.curMask.Dim(0),
|
||||||
value = value.View(ctx, elemSize*c.curCellRange.min,
|
)
|
||||||
cachedSize, value.Stride(1),
|
|
||||||
vHeadDim, value.Stride(2),
|
|
||||||
numKVHeads,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
vHeadDim := value.Dim(0)
|
|
||||||
rowSize := value.Stride(2)
|
|
||||||
|
|
||||||
value = value.View(ctx, rowSize*c.curCellRange.min,
|
|
||||||
vHeadDim, value.Stride(1),
|
|
||||||
numKVHeads, value.Stride(2),
|
|
||||||
cachedSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, value, c.curMask
|
return key, value, c.curMask
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||||
kHeadDim := key.Dim(0)
|
if c.curBatchSize != key.Dim(2) {
|
||||||
vHeadDim := value.Dim(0)
|
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
|
||||||
numKVHeads := key.Dim(1)
|
|
||||||
batchSize := key.Dim(2)
|
|
||||||
|
|
||||||
if c.curBatchSize != batchSize {
|
|
||||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||||
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
|
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
|
||||||
|
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := c.keys[c.curLayer]; !ok {
|
ctx.Forward(
|
||||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
|
||||||
}
|
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
|
||||||
|
)
|
||||||
if _, ok := c.values[c.curLayer]; !ok {
|
|
||||||
if c.config.PermutedV {
|
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
|
||||||
} else {
|
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rowSize := c.keys[c.curLayer].Stride(2)
|
|
||||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
|
||||||
|
|
||||||
if c.config.PermutedV {
|
|
||||||
elemSize := c.values[c.curLayer].Stride(0)
|
|
||||||
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
|
||||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
|
||||||
} else {
|
|
||||||
rowSize := c.values[c.curLayer].Stride(2)
|
|
||||||
|
|
||||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||||
@@ -548,7 +379,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
|
kShift, err := ctx.FromIntSlice(offsets, len(offsets))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -558,13 +389,9 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
kHeadDim := key.Dim(0)
|
key = key.View(ctx, key.Stride(2)*seqRange.min,
|
||||||
numKVHeads := key.Dim(1)
|
key.Dim(0), key.Stride(1),
|
||||||
rowSize := key.Stride(2)
|
key.Dim(1), key.Stride(2),
|
||||||
|
|
||||||
key = key.View(ctx, rowSize*seqRange.min,
|
|
||||||
kHeadDim, key.Stride(1),
|
|
||||||
numKVHeads, key.Stride(2),
|
|
||||||
size,
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
@@ -270,7 +269,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
|||||||
context := backend.NewContext()
|
context := backend.NewContext()
|
||||||
defer context.Close()
|
defer context.Close()
|
||||||
|
|
||||||
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
|
err := cache.StartForward(context, test.pos, test.seqs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -304,17 +303,13 @@ func (b *testBackend) NewContext() ml.Context {
|
|||||||
return &testContext{}
|
return &testContext{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *testBackend) NewContextSize(int) ml.Context {
|
|
||||||
return &testContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *testBackend) SystemInfo() string {
|
func (b *testBackend) SystemInfo() string {
|
||||||
return "not implemented"
|
return "not implemented"
|
||||||
}
|
}
|
||||||
|
|
||||||
type testContext struct{}
|
type testContext struct{}
|
||||||
|
|
||||||
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||||
total := 0
|
total := 0
|
||||||
|
|
||||||
if len(shape) > 0 {
|
if len(shape) > 0 {
|
||||||
@@ -327,12 +322,8 @@ func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
|||||||
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
|
||||||
return c.Empty(dtype, shape...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||||
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
|
||||||
|
|
||||||
copy(t.data, s)
|
copy(t.data, s)
|
||||||
|
|
||||||
@@ -351,15 +342,11 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *testContext) Input() ml.Context { return c }
|
|
||||||
func (c *testContext) Output() ml.Context { return c }
|
|
||||||
func (c *testContext) Layer(int) ml.Context { return c }
|
|
||||||
|
|
||||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||||
|
|
||||||
func (c *testContext) Compute(...ml.Tensor) {}
|
func (c *testContext) Compute(...ml.Tensor) {}
|
||||||
|
|
||||||
func (c *testContext) MaxGraphNodes() int {
|
func (c *testContext) MaxTensors() int {
|
||||||
return 10
|
return 10
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -404,7 +391,7 @@ func (t *testTensor) Floats() []float32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
|
||||||
|
|
||||||
for i := range out.data {
|
for i := range out.data {
|
||||||
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||||
@@ -441,19 +428,11 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
|
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,7 +468,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
|||||||
|
|
||||||
context := &testContext{}
|
context := &testContext{}
|
||||||
|
|
||||||
view := context.Empty(t.dtype, s...).(*testTensor)
|
view := context.Zeros(t.dtype, s...).(*testTensor)
|
||||||
view.data = t.data[offset : offset+len(view.data)]
|
view.data = t.data[offset : offset+len(view.data)]
|
||||||
|
|
||||||
return view
|
return view
|
||||||
@@ -503,10 +482,6 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
|
|||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
|
|
||||||
panic("not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||||
panic("not implemented")
|
panic("not implemented")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package kvcache
|
package kvcache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Encoder cache stores K and V tensors that are position independent
|
// Encoder cache stores K and V tensors that are position independent
|
||||||
@@ -14,9 +11,6 @@ import (
|
|||||||
//
|
//
|
||||||
// Not currently safe for multiple sequences
|
// Not currently safe for multiple sequences
|
||||||
type EncoderCache struct {
|
type EncoderCache struct {
|
||||||
// config controls mostly backend-specific optimizations
|
|
||||||
config *ml.CacheConfig
|
|
||||||
|
|
||||||
// ** current forward pass **
|
// ** current forward pass **
|
||||||
|
|
||||||
// the active layer for Get and Put
|
// the active layer for Get and Put
|
||||||
@@ -36,59 +30,36 @@ type EncoderCache struct {
|
|||||||
encoderPos int32
|
encoderPos int32
|
||||||
|
|
||||||
// ** cache data storage **
|
// ** cache data storage **
|
||||||
backend ml.Backend
|
|
||||||
ctxs map[int]ml.Context
|
cacheCtx ml.Context
|
||||||
keys, values map[int]ml.Tensor
|
keys, values []ml.Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewEncoderCache() *EncoderCache {
|
func NewEncoderCache() *EncoderCache {
|
||||||
return &EncoderCache{
|
return &EncoderCache{}
|
||||||
ctxs: make(map[int]ml.Context),
|
|
||||||
keys: make(map[int]ml.Tensor),
|
|
||||||
values: make(map[int]ml.Tensor),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||||
if c.config == nil {
|
c.cacheCtx = backend.NewContext()
|
||||||
var config ml.CacheConfig
|
|
||||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
|
||||||
config = cc.CacheConfig()
|
|
||||||
}
|
|
||||||
c.config = &config
|
|
||||||
}
|
|
||||||
|
|
||||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
|
||||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
|
||||||
}
|
|
||||||
|
|
||||||
c.backend = backend
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
|
||||||
if c.config != nil {
|
|
||||||
panic("config cannot be changed after being previously set, either by the model or backend")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.config = &config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) Close() {
|
func (c *EncoderCache) Close() {
|
||||||
for _, ctx := range c.ctxs {
|
c.cacheCtx.Close()
|
||||||
ctx.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||||
// We work with the most recent image
|
// The image is always in the first position
|
||||||
if len(opts.Multimodal) > 0 {
|
c.curPos = positions[0]
|
||||||
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *EncoderCache) SetLayer(layer int) {
|
func (c *EncoderCache) SetLayer(layer int) {
|
||||||
|
if layer >= len(c.keys) {
|
||||||
|
c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
|
||||||
|
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
|
||||||
|
}
|
||||||
|
|
||||||
c.curLayer = layer
|
c.curLayer = layer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,20 +75,9 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
|||||||
c.encoderPos = c.curPos
|
c.encoderPos = c.curPos
|
||||||
c.encoderCached = true
|
c.encoderCached = true
|
||||||
|
|
||||||
if c.config.PermutedV {
|
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
|
||||||
}
|
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
||||||
|
|
||||||
if _, ok := c.ctxs[c.curLayer]; !ok {
|
|
||||||
c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := c.keys[c.curLayer]; !ok {
|
|
||||||
c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := c.values[c.curLayer]; !ok {
|
|
||||||
c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Forward(
|
ctx.Forward(
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Wrapper cache is a container for multiple types of caches,
|
// Wrapper cache is a container for multiple types of caches,
|
||||||
@@ -29,26 +28,20 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
|
||||||
for _, cache := range c.caches {
|
|
||||||
cache.SetConfig(config)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WrapperCache) Close() {
|
func (c *WrapperCache) Close() {
|
||||||
for _, cache := range c.caches {
|
for _, cache := range c.caches {
|
||||||
cache.Close()
|
cache.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
|
func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
|
||||||
for i, cache := range c.caches {
|
for i, cache := range c.caches {
|
||||||
err := cache.StartForward(ctx, opts)
|
err := cache.StartForward(ctx, positions, seqs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||||
for j := i - 1; j >= 0; j-- {
|
for j := i - 1; j >= 0; j-- {
|
||||||
for k := range opts.Positions {
|
for k := range positions {
|
||||||
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
|
_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|||||||
2
llama/llama.cpp/src/llama-vocab.cpp
vendored
2
llama/llama.cpp/src/llama-vocab.cpp
vendored
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||||||
|
|
||||||
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
|
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
|
||||||
if (precompiled_charsmap_keyidx != -1) {
|
if (precompiled_charsmap_keyidx != -1) {
|
||||||
size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
|
size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
|
||||||
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
|
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
|
||||||
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
|
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
|
||||||
#ifdef IS_BIG_ENDIAN
|
#ifdef IS_BIG_ENDIAN
|
||||||
|
|||||||
@@ -21,6 +21,18 @@ package llama
|
|||||||
|
|
||||||
extern bool llamaProgressCallback(float progress, void *user_data);
|
extern bool llamaProgressCallback(float progress, void *user_data);
|
||||||
extern void llamaLog(int level, char* text, void* user_data);
|
extern void llamaLog(int level, char* text, void* user_data);
|
||||||
|
|
||||||
|
typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
|
||||||
|
COMPILER inline get_compiler() {
|
||||||
|
#if defined(__clang__)
|
||||||
|
return COMP_CLANG;
|
||||||
|
#elif defined(__GNUC__)
|
||||||
|
return COMP_GCC;
|
||||||
|
#else
|
||||||
|
return UNKNOWN_COMPILER;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
@@ -60,6 +72,19 @@ func BackendInit() {
|
|||||||
C.llama_backend_init()
|
C.llama_backend_init()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PrintSystemInfo() string {
|
||||||
|
var compiler string
|
||||||
|
switch C.get_compiler() {
|
||||||
|
case C.COMP_UNKNOWN:
|
||||||
|
compiler = "cgo(unknown_compiler)"
|
||||||
|
case C.COMP_GCC:
|
||||||
|
compiler = "cgo(gcc)"
|
||||||
|
case C.COMP_CLANG:
|
||||||
|
compiler = "cgo(clang)"
|
||||||
|
}
|
||||||
|
return C.GoString(C.llama_print_system_info()) + compiler
|
||||||
|
}
|
||||||
|
|
||||||
func GetModelArch(modelPath string) (string, error) {
|
func GetModelArch(modelPath string) (string, error) {
|
||||||
mp := C.CString(modelPath)
|
mp := C.CString(modelPath)
|
||||||
defer C.free(unsafe.Pointer(mp))
|
defer C.free(unsafe.Pointer(mp))
|
||||||
@@ -245,20 +270,6 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
|||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadVocabFromFile(path string) (*Vocab, error) {
|
|
||||||
mp := C.CString(path)
|
|
||||||
defer C.free(unsafe.Pointer(mp))
|
|
||||||
v := Vocab{c: C.llama_load_vocab_from_file(mp)}
|
|
||||||
if v.c == nil {
|
|
||||||
return nil, fmt.Errorf("unable to load vocab: %s", path)
|
|
||||||
}
|
|
||||||
return &v, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FreeVocab(vocab *Vocab) {
|
|
||||||
C.llama_free_vocab(vocab.c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func FreeModel(model *Model) {
|
func FreeModel(model *Model) {
|
||||||
C.llama_model_free(model.c)
|
C.llama_model_free(model.c)
|
||||||
}
|
}
|
||||||
@@ -307,10 +318,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Vocab struct {
|
|
||||||
c *C.struct_llama_vocab
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Vocab() *C.struct_llama_vocab {
|
func (m *Model) Vocab() *C.struct_llama_vocab {
|
||||||
return C.llama_model_get_vocab(m.c)
|
return C.llama_model_get_vocab(m.c)
|
||||||
}
|
}
|
||||||
@@ -687,53 +694,3 @@ func SchemaToGrammar(schema []byte) []byte {
|
|||||||
}
|
}
|
||||||
return buf[:n]
|
return buf[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
type Sampler struct {
|
|
||||||
c *C.struct_llama_sampler
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
|
|
||||||
cGrammar := C.CString(grammar)
|
|
||||||
cRoot := C.CString("root")
|
|
||||||
defer C.free(unsafe.Pointer(cGrammar))
|
|
||||||
defer C.free(unsafe.Pointer(cRoot))
|
|
||||||
|
|
||||||
sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
|
|
||||||
|
|
||||||
return sampler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sampler) Accept(token int32) {
|
|
||||||
C.llama_sampler_accept(s.c, C.llama_token(token))
|
|
||||||
}
|
|
||||||
|
|
||||||
type TokenData struct {
|
|
||||||
Id int32
|
|
||||||
Logit float32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sampler) Apply(tokens []TokenData) {
|
|
||||||
tds := make([]C.struct_llama_token_data, len(tokens))
|
|
||||||
for i, token := range tokens {
|
|
||||||
tds[i] = C.struct_llama_token_data{
|
|
||||||
id: C.int32_t(token.Id),
|
|
||||||
logit: C.float(token.Logit),
|
|
||||||
p: C.float(0.0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tda := &C.llama_token_data_array{
|
|
||||||
data: (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
|
|
||||||
size: C.size_t(len(tokens)),
|
|
||||||
selected: C.int64_t(-1),
|
|
||||||
sorted: C.bool(false),
|
|
||||||
}
|
|
||||||
|
|
||||||
var pinner runtime.Pinner
|
|
||||||
pinner.Pin(&tds[0])
|
|
||||||
defer pinner.Unpin()
|
|
||||||
|
|
||||||
C.llama_sampler_apply(s.c, tda)
|
|
||||||
for i := range tokens {
|
|
||||||
tokens[i].Logit = float32(tds[i].logit)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
69
llama/patches/0015-try-catch-backend-load.patch
Normal file
69
llama/patches/0015-try-catch-backend-load.patch
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Michael Yang <mxyng@pm.me>
|
||||||
|
Date: Tue, 11 Feb 2025 14:06:36 -0800
|
||||||
|
Subject: [PATCH] try/catch backend load
|
||||||
|
|
||||||
|
---
|
||||||
|
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
|
||||||
|
1 file changed, 23 insertions(+), 22 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||||
|
index 98d5e14d..1c19129a 100644
|
||||||
|
--- a/ggml/src/ggml-backend-reg.cpp
|
||||||
|
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||||
|
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
|
}
|
||||||
|
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||||
|
for (const auto & entry : dir_it) {
|
||||||
|
- if (entry.is_regular_file()) {
|
||||||
|
- std::wstring filename = entry.path().filename().wstring();
|
||||||
|
- std::wstring ext = entry.path().extension().wstring();
|
||||||
|
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||||
|
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||||
|
- if (!handle && !silent) {
|
||||||
|
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||||
|
- }
|
||||||
|
- if (handle) {
|
||||||
|
+ try {
|
||||||
|
+ if (entry.is_regular_file()) {
|
||||||
|
+ std::wstring filename = entry.path().filename().wstring();
|
||||||
|
+ std::wstring ext = entry.path().extension().wstring();
|
||||||
|
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||||
|
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||||
|
+ if (!handle) {
|
||||||
|
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||||
|
+ continue;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||||
|
- if (score_fn) {
|
||||||
|
- int s = score_fn();
|
||||||
|
-#ifndef NDEBUG
|
||||||
|
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||||
|
-#endif
|
||||||
|
- if (s > best_score) {
|
||||||
|
- best_score = s;
|
||||||
|
- best_path = entry.path().wstring();
|
||||||
|
- }
|
||||||
|
- } else {
|
||||||
|
- if (!silent) {
|
||||||
|
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||||
|
- }
|
||||||
|
+ if (!score_fn) {
|
||||||
|
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||||
|
+ continue;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ int s = score_fn();
|
||||||
|
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||||
|
+ if (s > best_score) {
|
||||||
|
+ best_score = s;
|
||||||
|
+ best_path = entry.path().wstring();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
+ } catch (const std::exception & e) {
|
||||||
|
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
|
|||||||
Subject: [PATCH] use std::filesystem::path instead of wstring
|
Subject: [PATCH] use std::filesystem::path instead of wstring
|
||||||
|
|
||||||
---
|
---
|
||||||
ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
|
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
|
||||||
1 file changed, 88 insertions(+), 111 deletions(-)
|
1 file changed, 58 insertions(+), 86 deletions(-)
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||||
index 98d5e14d..799af5f3 100644
|
index 1c19129a..c854e6bb 100644
|
||||||
--- a/ggml/src/ggml-backend-reg.cpp
|
--- a/ggml/src/ggml-backend-reg.cpp
|
||||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||||
@@ -66,26 +66,6 @@
|
@@ -66,26 +66,6 @@
|
||||||
@@ -264,55 +264,47 @@ index 98d5e14d..799af5f3 100644
|
|||||||
for (const auto & search_path : search_paths) {
|
for (const auto & search_path : search_paths) {
|
||||||
if (!fs::exists(search_path)) {
|
if (!fs::exists(search_path)) {
|
||||||
continue;
|
continue;
|
||||||
@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
|
||||||
for (const auto & entry : dir_it) {
|
for (const auto & entry : dir_it) {
|
||||||
if (entry.is_regular_file()) {
|
try {
|
||||||
- std::wstring filename = entry.path().filename().wstring();
|
if (entry.is_regular_file()) {
|
||||||
- std::wstring ext = entry.path().extension().wstring();
|
- std::wstring filename = entry.path().filename().wstring();
|
||||||
+ std::string filename = entry.path().filename().string();
|
- std::wstring ext = entry.path().extension().wstring();
|
||||||
+ std::string ext = entry.path().extension().string();
|
+ std::string filename = entry.path().filename().string();
|
||||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
+ std::string ext = entry.path().extension().string();
|
||||||
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||||
- if (!handle && !silent) {
|
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
|
||||||
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||||
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
|
if (!handle) {
|
||||||
+ if (!handle) {
|
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||||
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||||
+ continue;
|
continue;
|
||||||
}
|
}
|
||||||
- if (handle) {
|
|
||||||
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||||
- if (score_fn) {
|
if (!score_fn) {
|
||||||
- int s = score_fn();
|
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
||||||
-#ifndef NDEBUG
|
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||||
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
continue;
|
||||||
-#endif
|
}
|
||||||
- if (s > best_score) {
|
|
||||||
- best_score = s;
|
int s = score_fn();
|
||||||
- best_path = entry.path().wstring();
|
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
|
||||||
- }
|
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||||
- } else {
|
if (s > best_score) {
|
||||||
- if (!silent) {
|
best_score = s;
|
||||||
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
|
- best_path = entry.path().wstring();
|
||||||
- }
|
+ best_path = entry.path();
|
||||||
- }
|
}
|
||||||
+
|
|
||||||
+ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
|
||||||
+ if (!score_fn) {
|
|
||||||
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
|
||||||
+ continue;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ int s = score_fn();
|
|
||||||
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
|
||||||
+ if (s > best_score) {
|
|
||||||
+ best_score = s;
|
|
||||||
+ best_path = entry.path();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
|
||||||
|
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
|
||||||
}
|
}
|
||||||
@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
}
|
||||||
|
}
|
||||||
|
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
||||||
if (best_score == 0) {
|
if (best_score == 0) {
|
||||||
// try to load the base backend
|
// try to load the base backend
|
||||||
for (const auto & search_path : search_paths) {
|
for (const auto & search_path : search_paths) {
|
||||||
@@ -321,49 +313,3 @@ index 98d5e14d..799af5f3 100644
|
|||||||
if (fs::exists(path)) {
|
if (fs::exists(path)) {
|
||||||
return get_reg().load_backend(path, silent);
|
return get_reg().load_backend(path, silent);
|
||||||
}
|
}
|
||||||
@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
|
|
||||||
ggml_backend_load_all_from_path(nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
|
|
||||||
+ try {
|
|
||||||
+ ggml_backend_load_best(name, silent, user_search_path);
|
|
||||||
+ } catch (const std::exception & e) {
|
|
||||||
+ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
|
|
||||||
+ }
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
void ggml_backend_load_all_from_path(const char * dir_path) {
|
|
||||||
#ifdef NDEBUG
|
|
||||||
bool silent = true;
|
|
||||||
@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|
||||||
bool silent = false;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
- ggml_backend_load_best("blas", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("cann", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("cuda", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("hip", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("kompute", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("metal", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("rpc", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("sycl", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("vulkan", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("opencl", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("musa", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("cpu", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("blas", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("cann", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("cuda", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("hip", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("kompute", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("metal", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("rpc", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("sycl", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("vulkan", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("opencl", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("musa", silent, dir_path);
|
|
||||||
+ ggml_backend_try_load_best("cpu", silent, dir_path);
|
|
||||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
|
||||||
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
|
|
||||||
if (backend_path) {
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
||||||
From: jmorganca <jmorganca@gmail.com>
|
|
||||||
Date: Wed, 5 Mar 2025 17:41:07 -0800
|
|
||||||
Subject: [PATCH] fix string arr kv loading
|
|
||||||
|
|
||||||
---
|
|
||||||
ggml/include/gguf.h | 1 +
|
|
||||||
ggml/src/gguf.cpp | 7 +++++--
|
|
||||||
src/llama-vocab.cpp | 2 +-
|
|
||||||
3 files changed, 7 insertions(+), 3 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
|
|
||||||
index 79ee2020..3efb22f0 100644
|
|
||||||
--- a/ggml/include/gguf.h
|
|
||||||
+++ b/ggml/include/gguf.h
|
|
||||||
@@ -114,6 +114,7 @@ extern "C" {
|
|
||||||
// get raw pointer to the first element of the array with the given key_id
|
|
||||||
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
|
|
||||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
+ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
|
|
||||||
// get ith C string from array with given key_id
|
|
||||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
|
||||||
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
|
|
||||||
index ab13669c..f75b923f 100644
|
|
||||||
--- a/ggml/src/gguf.cpp
|
|
||||||
+++ b/ggml/src/gguf.cpp
|
|
||||||
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
|
|
||||||
|
|
||||||
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
|
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
|
||||||
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
|
|
||||||
return ctx->kv[key_id].data.data();
|
|
||||||
}
|
|
||||||
|
|
||||||
+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
|
|
||||||
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
|
||||||
+ return ctx->kv[key_id].data.size();
|
|
||||||
+}
|
|
||||||
+
|
|
||||||
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
|
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
|
||||||
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
|
|
||||||
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
|
|
||||||
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
|
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
|
||||||
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
|
|
||||||
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
|
|
||||||
return ctx->kv[key_id].data.data();
|
|
||||||
}
|
|
||||||
|
|
||||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
|
||||||
index c7ff28be..7a185443 100644
|
|
||||||
--- a/src/llama-vocab.cpp
|
|
||||||
+++ b/src/llama-vocab.cpp
|
|
||||||
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|
||||||
|
|
||||||
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
|
|
||||||
if (precompiled_charsmap_keyidx != -1) {
|
|
||||||
- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
|
|
||||||
+ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
|
|
||||||
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
|
|
||||||
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
|
|
||||||
#ifdef IS_BIG_ENDIAN
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
||||||
From: Michael Yang <mxyng@pm.me>
|
|
||||||
Date: Sun, 9 Mar 2025 14:44:16 -0700
|
|
||||||
Subject: [PATCH] ollama debug tensor
|
|
||||||
|
|
||||||
---
|
|
||||||
ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++
|
|
||||||
1 file changed, 6 insertions(+)
|
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
|
|
||||||
index 2f606d82..ec60e8fc 100644
|
|
||||||
--- a/ggml/src/ggml-cpu/ggml-cpu.c
|
|
||||||
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
|
|
||||||
@@ -11,6 +11,8 @@
|
|
||||||
#include "ggml-threading.h"
|
|
||||||
#include "ggml.h"
|
|
||||||
|
|
||||||
+#include "ollama-debug.h"
|
|
||||||
+
|
|
||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
|
||||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
|
||||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
|
||||||
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
|
||||||
|
|
||||||
ggml_compute_forward(¶ms, node);
|
|
||||||
|
|
||||||
+#ifdef OLLAMA_DEBUG
|
|
||||||
+ ollama_debug(node, true);
|
|
||||||
+#endif
|
|
||||||
+
|
|
||||||
if (state->ith == 0 && cplan->abort_callback &&
|
|
||||||
cplan->abort_callback(cplan->abort_callback_data)) {
|
|
||||||
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
|
||||||
22
llama/sampling_ext.cpp
vendored
22
llama/sampling_ext.cpp
vendored
@@ -2,9 +2,6 @@
|
|||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "sampling_ext.h"
|
#include "sampling_ext.h"
|
||||||
#include "json-schema-to-grammar.h"
|
#include "json-schema-to-grammar.h"
|
||||||
#include "llama.h"
|
|
||||||
#include "llama-model.h"
|
|
||||||
#include "llama-model-loader.h"
|
|
||||||
|
|
||||||
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
|
struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
|
||||||
try {
|
try {
|
||||||
@@ -67,22 +64,3 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
|
|
||||||
llama_vocab * vocab = new llama_vocab();
|
|
||||||
try {
|
|
||||||
const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
|
|
||||||
std::vector<std::string> splits = {};
|
|
||||||
llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
|
|
||||||
vocab->load(ml, kv);
|
|
||||||
} catch (const std::exception & err) {
|
|
||||||
LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
return vocab;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_free_vocab(struct llama_vocab * vocab) {
|
|
||||||
delete vocab;
|
|
||||||
}
|
|
||||||
|
|||||||
3
llama/sampling_ext.h
vendored
3
llama/sampling_ext.h
vendored
@@ -35,9 +35,6 @@ extern "C"
|
|||||||
|
|
||||||
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
|
||||||
|
|
||||||
struct llama_vocab * llama_load_vocab_from_file(const char * fname);
|
|
||||||
void llama_free_vocab(struct llama_vocab * vocab);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -115,9 +115,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
// multimodal models require at least 2048 context
|
// multimodal models require at least 2048 context
|
||||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||||
}
|
}
|
||||||
if projectorWeights == 0 && projectorGraph == 0 {
|
|
||||||
projectorWeights, projectorGraph = f.VisionGraphSize()
|
|
||||||
}
|
|
||||||
|
|
||||||
layers := f.Tensors().GroupLayers()
|
layers := f.Tensors().GroupLayers()
|
||||||
// add one layer worth of memory as a buffer
|
// add one layer worth of memory as a buffer
|
||||||
|
|||||||
247
llm/server.go
247
llm/server.go
@@ -30,7 +30,6 @@ import (
|
|||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/llama"
|
"github.com/ollama/ollama/llama"
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type LlamaServer interface {
|
type LlamaServer interface {
|
||||||
@@ -55,15 +54,8 @@ type llmServer struct {
|
|||||||
options api.Options
|
options api.Options
|
||||||
numParallel int
|
numParallel int
|
||||||
modelPath string
|
modelPath string
|
||||||
|
modelLock sync.Mutex // Temporary until we switch fully to Go server
|
||||||
// llamaModel is an instance of the cgo llama.cpp model definition
|
model *llama.Model // If non-nil, the runner is a new Go server
|
||||||
// nil if this server is running the new engine
|
|
||||||
llamaModel *llama.Model
|
|
||||||
llamaModelLock sync.Mutex
|
|
||||||
|
|
||||||
// textProcessor handles text encoding/decoding for the model in the Ollama engine
|
|
||||||
// nil if this server is running the llama.cpp based engine
|
|
||||||
textProcessor model.TextProcessor
|
|
||||||
|
|
||||||
estimate MemoryEstimate
|
estimate MemoryEstimate
|
||||||
totalLayers uint64
|
totalLayers uint64
|
||||||
@@ -97,7 +89,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
|||||||
|
|
||||||
// NewLlamaServer will run a server for the given GPUs
|
// NewLlamaServer will run a server for the given GPUs
|
||||||
// The gpu list must be a single family.
|
// The gpu list must be a single family.
|
||||||
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||||
systemInfo := discover.GetSystemInfo()
|
systemInfo := discover.GetSystemInfo()
|
||||||
systemTotalMemory := systemInfo.System.TotalMemory
|
systemTotalMemory := systemInfo.System.TotalMemory
|
||||||
systemFreeMemory := systemInfo.System.FreeMemory
|
systemFreeMemory := systemInfo.System.FreeMemory
|
||||||
@@ -138,7 +130,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
slog.Info("offload", "", estimate)
|
slog.Info("offload", "", estimate)
|
||||||
|
|
||||||
params := []string{
|
params := []string{
|
||||||
"--model", modelPath,
|
"--model", model,
|
||||||
"--ctx-size", strconv.Itoa(opts.NumCtx),
|
"--ctx-size", strconv.Itoa(opts.NumCtx),
|
||||||
"--batch-size", strconv.Itoa(opts.NumBatch),
|
"--batch-size", strconv.Itoa(opts.NumBatch),
|
||||||
}
|
}
|
||||||
@@ -161,6 +153,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(projectors) > 0 {
|
||||||
|
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||||
|
params = append(params, "--mmproj", projectors[0])
|
||||||
|
}
|
||||||
|
|
||||||
defaultThreads := systemInfo.GetOptimalThreadCount()
|
defaultThreads := systemInfo.GetOptimalThreadCount()
|
||||||
if opts.NumThread > 0 {
|
if opts.NumThread > 0 {
|
||||||
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
|
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
|
||||||
@@ -260,34 +257,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
slog.Debug("compatible gpu libraries", "compatible", compatible)
|
slog.Debug("compatible gpu libraries", "compatible", compatible)
|
||||||
exe, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
|
||||||
exe = eval
|
|
||||||
}
|
|
||||||
|
|
||||||
var llamaModel *llama.Model
|
|
||||||
var textProcessor model.TextProcessor
|
|
||||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
|
||||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
|
||||||
if err != nil {
|
|
||||||
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
|
|
||||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if textProcessor == nil {
|
|
||||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(projectors) > 0 && llamaModel != nil {
|
|
||||||
params = append(params, "--mmproj", projectors[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
|
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
|
||||||
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
||||||
@@ -306,9 +275,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||||
}
|
}
|
||||||
finalParams := []string{"runner"}
|
finalParams := []string{"runner"}
|
||||||
if textProcessor != nil {
|
if envconfig.NewEngine() {
|
||||||
// New engine
|
|
||||||
// TODO - if we have failure to load scenarios, add logic to retry with the old runner
|
|
||||||
finalParams = append(finalParams, "--ollama-engine")
|
finalParams = append(finalParams, "--ollama-engine")
|
||||||
}
|
}
|
||||||
finalParams = append(finalParams, params...)
|
finalParams = append(finalParams, params...)
|
||||||
@@ -348,20 +315,28 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
// finally, add the root library path
|
// finally, add the root library path
|
||||||
libraryPaths = append(libraryPaths, discover.LibOllamaPath)
|
libraryPaths = append(libraryPaths, discover.LibOllamaPath)
|
||||||
|
|
||||||
|
exe, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||||
|
exe = eval
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
|
||||||
s := &llmServer{
|
s := &llmServer{
|
||||||
port: port,
|
port: port,
|
||||||
cmd: exec.Command(exe, finalParams...),
|
cmd: exec.Command(exe, finalParams...),
|
||||||
status: NewStatusWriter(os.Stderr),
|
status: NewStatusWriter(os.Stderr),
|
||||||
options: opts,
|
options: opts,
|
||||||
modelPath: modelPath,
|
modelPath: model,
|
||||||
llamaModel: llamaModel,
|
estimate: estimate,
|
||||||
textProcessor: textProcessor,
|
numParallel: numParallel,
|
||||||
estimate: estimate,
|
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||||
numParallel: numParallel,
|
totalLayers: f.KV().BlockCount() + 1,
|
||||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
gpus: gpus,
|
||||||
totalLayers: f.KV().BlockCount() + 1,
|
done: make(chan error, 1),
|
||||||
gpus: gpus,
|
|
||||||
done: make(chan error, 1),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.cmd.Env = os.Environ()
|
s.cmd.Env = os.Environ()
|
||||||
@@ -430,9 +405,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
}
|
}
|
||||||
err := fmt.Errorf("error starting runner: %v %s", err, msg)
|
err := fmt.Errorf("error starting runner: %v %s", err, msg)
|
||||||
if len(compatible) == 0 {
|
if len(compatible) == 0 {
|
||||||
if llamaModel != nil {
|
|
||||||
llama.FreeModel(llamaModel)
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -961,25 +933,64 @@ type TokenizeResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
s.llamaModelLock.Lock()
|
s.modelLock.Lock()
|
||||||
defer s.llamaModelLock.Unlock()
|
defer s.modelLock.Unlock()
|
||||||
|
if s.model != nil {
|
||||||
|
return s.model.Tokenize(content, false, true)
|
||||||
|
}
|
||||||
|
|
||||||
if s.llamaModel != nil {
|
// Make sure the server is ready
|
||||||
return s.llamaModel.Tokenize(content, false, true)
|
status, err := s.getServerStatus(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||||
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
if s.textProcessor != nil {
|
|
||||||
tokens, err := s.textProcessor.Encode(content, false)
|
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||||
}
|
|
||||||
toks := make([]int, len(tokens))
|
|
||||||
for i, t := range tokens {
|
|
||||||
toks[i] = int(t)
|
|
||||||
}
|
|
||||||
return toks, nil
|
|
||||||
}
|
}
|
||||||
// not reached
|
|
||||||
return nil, fmt.Errorf("no tokenizer configured")
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encode request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("do encode request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
if s.model == nil {
|
||||||
|
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||||
|
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.model = m
|
||||||
|
}
|
||||||
|
return s.model.Tokenize(content, false, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read encode request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
log.Printf("llm encode error: %s", body)
|
||||||
|
return nil, fmt.Errorf("%s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var encoded TokenizeResponse
|
||||||
|
if err := json.Unmarshal(body, &encoded); err != nil {
|
||||||
|
return nil, fmt.Errorf("unmarshal encode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return encoded.Tokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type DetokenizeRequest struct {
|
type DetokenizeRequest struct {
|
||||||
@@ -991,38 +1002,80 @@ type DetokenizeResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||||
s.llamaModelLock.Lock()
|
s.modelLock.Lock()
|
||||||
defer s.llamaModelLock.Unlock()
|
defer s.modelLock.Unlock()
|
||||||
|
if s.model != nil {
|
||||||
if s.llamaModel != nil {
|
|
||||||
var resp string
|
var resp string
|
||||||
for _, token := range tokens {
|
for _, token := range tokens {
|
||||||
resp += s.llamaModel.TokenToPiece(token)
|
resp += s.model.TokenToPiece(token)
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
if s.textProcessor != nil {
|
// Make sure the server is ready
|
||||||
toks := make([]int32, len(tokens))
|
status, err := s.getServerStatus(ctx)
|
||||||
for i, t := range tokens {
|
if err != nil {
|
||||||
toks[i] = int32(t)
|
return "", err
|
||||||
}
|
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||||
content, err := s.textProcessor.Decode(toks)
|
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
// not reached
|
|
||||||
return "", fmt.Errorf("no tokenizer configured")
|
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("decode request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("do decode request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
if s.model == nil {
|
||||||
|
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||||
|
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
s.model = m
|
||||||
|
}
|
||||||
|
var resp string
|
||||||
|
for _, token := range tokens {
|
||||||
|
resp += s.model.TokenToPiece(token)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read decode request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode >= 400 {
|
||||||
|
log.Printf("llm decode error: %s", body)
|
||||||
|
return "", fmt.Errorf("%s", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decoded DetokenizeResponse
|
||||||
|
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||||
|
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decoded.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Close() error {
|
func (s *llmServer) Close() error {
|
||||||
s.llamaModelLock.Lock()
|
s.modelLock.Lock()
|
||||||
if s.llamaModel != nil {
|
if s.model != nil {
|
||||||
llama.FreeModel(s.llamaModel)
|
llama.FreeModel(s.model)
|
||||||
s.llamaModel = nil
|
s.model = nil
|
||||||
}
|
}
|
||||||
s.llamaModelLock.Unlock()
|
s.modelLock.Unlock()
|
||||||
|
|
||||||
if s.cmd != nil {
|
if s.cmd != nil {
|
||||||
slog.Debug("stopping llama server")
|
slog.Debug("stopping llama server")
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
package logging
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
const LevelTrace slog.Level = slog.LevelDebug - 4
|
|
||||||
|
|
||||||
type Logger struct {
|
|
||||||
logger *slog.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewLogger() *Logger {
|
|
||||||
handler := slog.NewTextHandler(os.Stdout, nil)
|
|
||||||
return &Logger{
|
|
||||||
logger: slog.New(handler),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Trace(msg string, args ...any) {
|
|
||||||
l.logger.Log(context.Background(), LevelTrace, msg, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Debug(msg string, args ...any) {
|
|
||||||
l.logger.Debug(msg, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Info(msg string, args ...any) {
|
|
||||||
l.logger.Info(msg, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Warn(msg string, args ...any) {
|
|
||||||
l.logger.Warn(msg, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Error(msg string, args ...any) {
|
|
||||||
l.logger.Error(msg, args...)
|
|
||||||
}
|
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -19,43 +18,13 @@ type Config interface {
|
|||||||
|
|
||||||
Strings(string, ...[]string) []string
|
Strings(string, ...[]string) []string
|
||||||
Uints(string, ...[]uint32) []uint32
|
Uints(string, ...[]uint32) []uint32
|
||||||
Floats(string, ...[]float32) []float32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Backend interface {
|
type Backend interface {
|
||||||
Config() Config
|
Config() Config
|
||||||
Get(name string) Tensor
|
Get(name string) Tensor
|
||||||
NewContext() Context
|
NewContext() Context
|
||||||
NewContextSize(size int) Context
|
SystemInfo() string
|
||||||
}
|
|
||||||
|
|
||||||
// BackendCacheConfig should be implemented by backends that need special output
|
|
||||||
// from the cache to meet specific requirements. It is frequently implemented in
|
|
||||||
// conjunction with ScaledDotProductAttention.
|
|
||||||
type BackendCacheConfig interface {
|
|
||||||
CacheConfig() CacheConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
|
||||||
// the output the cache to work better with specific kernels.
|
|
||||||
type CacheConfig struct {
|
|
||||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
|
||||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
|
||||||
// cache itself will also be increased to a multiple of this size if needed.
|
|
||||||
CachePadding int
|
|
||||||
|
|
||||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
|
||||||
// and return the permuted version via Get. This uses the cache copy operation
|
|
||||||
// to avoid a Contiguous call on the permuted tensor.
|
|
||||||
PermutedV bool
|
|
||||||
|
|
||||||
// MaskDType specifies the data type for generating the mask. If unset it will
|
|
||||||
// default to DTypeF32.
|
|
||||||
MaskDType DType
|
|
||||||
|
|
||||||
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
|
||||||
// Any position that does not correspond to an actual token will be filled with -Inf.
|
|
||||||
MaskBatchPadding int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackendParams controls how the backend loads and executes models
|
// BackendParams controls how the backend loads and executes models
|
||||||
@@ -71,9 +40,6 @@ type BackendParams struct {
|
|||||||
|
|
||||||
// TensorSplit is the fraction of the model to offload to each GPU
|
// TensorSplit is the fraction of the model to offload to each GPU
|
||||||
TensorSplit []float32
|
TensorSplit []float32
|
||||||
|
|
||||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
|
||||||
FlashAttention bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
||||||
@@ -95,24 +61,14 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Context interface {
|
type Context interface {
|
||||||
Empty(dtype DType, shape ...int) Tensor
|
|
||||||
Zeros(dtype DType, shape ...int) Tensor
|
Zeros(dtype DType, shape ...int) Tensor
|
||||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||||
|
|
||||||
Forward(...Tensor) Context
|
Forward(...Tensor) Context
|
||||||
Compute(...Tensor)
|
Compute(...Tensor)
|
||||||
MaxGraphNodes() int
|
MaxTensors() int
|
||||||
Close()
|
Close()
|
||||||
|
|
||||||
// Input returns a context appropriate for creating input tensors
|
|
||||||
Input() Context
|
|
||||||
|
|
||||||
// Output returns a context appropriate for creating output tensors
|
|
||||||
Output() Context
|
|
||||||
|
|
||||||
// Layer returns a context appropriate for creating intermediate tensors
|
|
||||||
Layer(int) Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Tensor interface {
|
type Tensor interface {
|
||||||
@@ -135,10 +91,8 @@ type Tensor interface {
|
|||||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||||
Scale(ctx Context, s float64) Tensor
|
Scale(ctx Context, s float64) Tensor
|
||||||
|
|
||||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
|
||||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||||
|
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor
|
||||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
|
||||||
|
|
||||||
Tanh(ctx Context) Tensor
|
Tanh(ctx Context) Tensor
|
||||||
GELU(ctx Context) Tensor
|
GELU(ctx Context) Tensor
|
||||||
@@ -148,7 +102,6 @@ type Tensor interface {
|
|||||||
View(ctx Context, offset int, shape ...int) Tensor
|
View(ctx Context, offset int, shape ...int) Tensor
|
||||||
Permute(ctx Context, shape ...int) Tensor
|
Permute(ctx Context, shape ...int) Tensor
|
||||||
Contiguous(ctx Context) Tensor
|
Contiguous(ctx Context) Tensor
|
||||||
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
|
|
||||||
|
|
||||||
Pad(ctx Context, shape ...int) Tensor
|
Pad(ctx Context, shape ...int) Tensor
|
||||||
Unpad(ctx Context, shape ...int) Tensor
|
Unpad(ctx Context, shape ...int) Tensor
|
||||||
@@ -163,10 +116,6 @@ type Tensor interface {
|
|||||||
// operation equivalent to following code on a tensor named
|
// operation equivalent to following code on a tensor named
|
||||||
// query:
|
// query:
|
||||||
//
|
//
|
||||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
//
|
|
||||||
// kq := key.MulmatFullPrec(ctx, query)
|
// kq := key.MulmatFullPrec(ctx, query)
|
||||||
//
|
//
|
||||||
// kq = kq.Scale(ctx, scale)
|
// kq = kq.Scale(ctx, scale)
|
||||||
@@ -220,8 +169,8 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
|||||||
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
|
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
|
||||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||||
})
|
})
|
||||||
case DTypeF16, DTypeQ80, DTypeQ40:
|
case DTypeF16:
|
||||||
f32 := ctx.Empty(DTypeF32, t.Shape()...)
|
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
|
||||||
f32 = t.Copy(ctx, f32)
|
f32 = t.Copy(ctx, f32)
|
||||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||||
@@ -246,17 +195,16 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
shape := t.Shape()
|
shape := t.Shape()
|
||||||
slices.Reverse(shape)
|
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var f func([]int, int)
|
var f func([]int, int)
|
||||||
f = func(dims []int, stride int) {
|
f = func(dims []int, stride int) {
|
||||||
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
|
||||||
sb.WriteString("[")
|
fmt.Fprint(&sb, "[")
|
||||||
defer func() { sb.WriteString("]") }()
|
defer func() { fmt.Fprint(&sb, "]") }()
|
||||||
for i := 0; i < dims[0]; i++ {
|
for i := 0; i < dims[0]; i++ {
|
||||||
if i >= items && i < dims[0]-items {
|
if i >= items && i < dims[0]-items {
|
||||||
sb.WriteString("..., ")
|
fmt.Fprint(&sb, "..., ")
|
||||||
// skip to next printable element
|
// skip to next printable element
|
||||||
skip := dims[0] - 2*items
|
skip := dims[0] - 2*items
|
||||||
if len(dims) > 1 {
|
if len(dims) > 1 {
|
||||||
@@ -271,14 +219,9 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
|
|||||||
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
text := fn(s[stride+i])
|
fmt.Fprint(&sb, fn(s[stride+i]))
|
||||||
if len(text) > 0 && text[0] != '-' {
|
|
||||||
sb.WriteString(" ")
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString(text)
|
|
||||||
if i < dims[0]-1 {
|
if i < dims[0]-1 {
|
||||||
sb.WriteString(", ")
|
fmt.Fprint(&sb, ", ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -294,7 +237,5 @@ const (
|
|||||||
DTypeOther DType = iota
|
DTypeOther DType = iota
|
||||||
DTypeF32
|
DTypeF32
|
||||||
DTypeF16
|
DTypeF16
|
||||||
DTypeQ80
|
|
||||||
DTypeQ40
|
|
||||||
DTypeI32
|
DTypeI32
|
||||||
)
|
)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
1
ml/backend/ggml/ggml/include/gguf.h
vendored
1
ml/backend/ggml/ggml/include/gguf.h
vendored
@@ -114,7 +114,6 @@ extern "C" {
|
|||||||
// get raw pointer to the first element of the array with the given key_id
|
// get raw pointer to the first element of the array with the given key_id
|
||||||
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
|
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
|
||||||
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
|
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
|
||||||
GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
|
|
||||||
|
|
||||||
// get ith C string from array with given key_id
|
// get ith C string from array with given key_id
|
||||||
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
|
||||||
|
|||||||
11
ml/backend/ggml/ggml/include/ollama-debug.h
vendored
11
ml/backend/ggml/ggml/include/ollama-debug.h
vendored
@@ -1,11 +0,0 @@
|
|||||||
#include "ggml.h"
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
extern "C" {
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void ollama_debug(const struct ggml_tensor *tensor, bool verbose);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
74
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
74
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
@@ -484,29 +484,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
|
|||||||
}
|
}
|
||||||
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
|
||||||
for (const auto & entry : dir_it) {
|
for (const auto & entry : dir_it) {
|
||||||
if (entry.is_regular_file()) {
|
try {
|
||||||
std::string filename = entry.path().filename().string();
|
if (entry.is_regular_file()) {
|
||||||
std::string ext = entry.path().extension().string();
|
std::string filename = entry.path().filename().string();
|
||||||
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
std::string ext = entry.path().extension().string();
|
||||||
dl_handle_ptr handle { dl_load_library(entry.path()) };
|
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
|
||||||
if (!handle) {
|
dl_handle_ptr handle { dl_load_library(entry.path()) };
|
||||||
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
if (!handle) {
|
||||||
continue;
|
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||||
}
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
|
||||||
if (!score_fn) {
|
if (!score_fn) {
|
||||||
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int s = score_fn();
|
int s = score_fn();
|
||||||
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
|
||||||
if (s > best_score) {
|
if (s > best_score) {
|
||||||
best_score = s;
|
best_score = s;
|
||||||
best_path = entry.path();
|
best_path = entry.path();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (const std::exception & e) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,14 +533,6 @@ void ggml_backend_load_all() {
|
|||||||
ggml_backend_load_all_from_path(nullptr);
|
ggml_backend_load_all_from_path(nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
|
|
||||||
try {
|
|
||||||
ggml_backend_load_best(name, silent, user_search_path);
|
|
||||||
} catch (const std::exception & e) {
|
|
||||||
GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_backend_load_all_from_path(const char * dir_path) {
|
void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||||
#ifdef NDEBUG
|
#ifdef NDEBUG
|
||||||
bool silent = true;
|
bool silent = true;
|
||||||
@@ -544,18 +540,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||||||
bool silent = false;
|
bool silent = false;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
ggml_backend_try_load_best("blas", silent, dir_path);
|
ggml_backend_load_best("blas", silent, dir_path);
|
||||||
ggml_backend_try_load_best("cann", silent, dir_path);
|
ggml_backend_load_best("cann", silent, dir_path);
|
||||||
ggml_backend_try_load_best("cuda", silent, dir_path);
|
ggml_backend_load_best("cuda", silent, dir_path);
|
||||||
ggml_backend_try_load_best("hip", silent, dir_path);
|
ggml_backend_load_best("hip", silent, dir_path);
|
||||||
ggml_backend_try_load_best("kompute", silent, dir_path);
|
ggml_backend_load_best("kompute", silent, dir_path);
|
||||||
ggml_backend_try_load_best("metal", silent, dir_path);
|
ggml_backend_load_best("metal", silent, dir_path);
|
||||||
ggml_backend_try_load_best("rpc", silent, dir_path);
|
ggml_backend_load_best("rpc", silent, dir_path);
|
||||||
ggml_backend_try_load_best("sycl", silent, dir_path);
|
ggml_backend_load_best("sycl", silent, dir_path);
|
||||||
ggml_backend_try_load_best("vulkan", silent, dir_path);
|
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||||
ggml_backend_try_load_best("opencl", silent, dir_path);
|
ggml_backend_load_best("opencl", silent, dir_path);
|
||||||
ggml_backend_try_load_best("musa", silent, dir_path);
|
ggml_backend_load_best("musa", silent, dir_path);
|
||||||
ggml_backend_try_load_best("cpu", silent, dir_path);
|
ggml_backend_load_best("cpu", silent, dir_path);
|
||||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||||
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
|
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
|
||||||
if (backend_path) {
|
if (backend_path) {
|
||||||
|
|||||||
@@ -1,6 +0,0 @@
|
|||||||
//go:build debug
|
|
||||||
|
|
||||||
package cpu
|
|
||||||
|
|
||||||
// #cgo CPPFLAGS: -DOLLAMA_DEBUG
|
|
||||||
import "C"
|
|
||||||
6
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
vendored
6
ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
vendored
@@ -11,8 +11,6 @@
|
|||||||
#include "ggml-threading.h"
|
#include "ggml-threading.h"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#include "ollama-debug.h"
|
|
||||||
|
|
||||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||||
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
|
||||||
@@ -14105,10 +14103,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
|||||||
|
|
||||||
ggml_compute_forward(¶ms, node);
|
ggml_compute_forward(¶ms, node);
|
||||||
|
|
||||||
#ifdef OLLAMA_DEBUG
|
|
||||||
ollama_debug(node, true);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (state->ith == 0 && cplan->abort_callback &&
|
if (state->ith == 0 && cplan->abort_callback &&
|
||||||
cplan->abort_callback(cplan->abort_callback_data)) {
|
cplan->abort_callback(cplan->abort_callback_data)) {
|
||||||
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
|
||||||
|
|||||||
@@ -7,20 +7,6 @@ package ggml
|
|||||||
// #include <stdlib.h>
|
// #include <stdlib.h>
|
||||||
// #include "ggml-backend.h"
|
// #include "ggml-backend.h"
|
||||||
// extern void sink(int level, char *text, void *user_data);
|
// extern void sink(int level, char *text, void *user_data);
|
||||||
// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); }
|
|
||||||
// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; }
|
|
||||||
/*
|
|
||||||
typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER;
|
|
||||||
static COMPILER compiler_name(void) {
|
|
||||||
#if defined(__clang__)
|
|
||||||
return COMPILER_CLANG;
|
|
||||||
#elif defined(__GNUC__)
|
|
||||||
return COMPILER_GNUC;
|
|
||||||
#else
|
|
||||||
return COMPILER_UNKNOWN;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -30,7 +16,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -105,43 +90,4 @@ var OnceLoad = sync.OnceFunc(func() {
|
|||||||
visited[abspath] = struct{}{}
|
visited[abspath] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("system", "", system{})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
type system struct{}
|
|
||||||
|
|
||||||
func (system) LogValue() slog.Value {
|
|
||||||
var attrs []slog.Attr
|
|
||||||
names := make(map[string]int)
|
|
||||||
for i := range C.ggml_backend_dev_count() {
|
|
||||||
r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i))
|
|
||||||
|
|
||||||
func() {
|
|
||||||
fName := C.CString("ggml_backend_get_features")
|
|
||||||
defer C.free(unsafe.Pointer(fName))
|
|
||||||
|
|
||||||
if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil {
|
|
||||||
var features []any
|
|
||||||
for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) {
|
|
||||||
features = append(features, C.GoString(f.name), C.GoString(f.value))
|
|
||||||
}
|
|
||||||
|
|
||||||
name := C.GoString(C.ggml_backend_reg_name(r))
|
|
||||||
attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...))
|
|
||||||
names[name] += 1
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
switch C.compiler_name() {
|
|
||||||
case C.COMPILER_CLANG:
|
|
||||||
attrs = append(attrs, slog.String("compiler", "cgo(clang)"))
|
|
||||||
case C.COMPILER_GNUC:
|
|
||||||
attrs = append(attrs, slog.String("compiler", "cgo(gcc)"))
|
|
||||||
default:
|
|
||||||
attrs = append(attrs, slog.String("compiler", "cgo(unknown)"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return slog.GroupValue(attrs...)
|
|
||||||
}
|
|
||||||
|
|||||||
7
ml/backend/ggml/ggml/src/gguf.cpp
vendored
7
ml/backend/ggml/ggml/src/gguf.cpp
vendored
@@ -777,14 +777,10 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
|
|||||||
|
|
||||||
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
|
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
|
GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
|
||||||
return ctx->kv[key_id].data.data();
|
return ctx->kv[key_id].data.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
|
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
|
||||||
return ctx->kv[key_id].data.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
|
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
|
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
|
||||||
@@ -878,6 +874,7 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
|
|||||||
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
|
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
|
||||||
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
|
||||||
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
|
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
|
||||||
|
GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
|
||||||
return ctx->kv[key_id].data.data();
|
return ctx->kv[key_id].data.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
116
ml/backend/ggml/ggml/src/ollama-debug.c
vendored
116
ml/backend/ggml/ggml/src/ollama-debug.c
vendored
@@ -1,116 +0,0 @@
|
|||||||
#include <string.h>
|
|
||||||
#include <inttypes.h>
|
|
||||||
|
|
||||||
#include "ollama-debug.h"
|
|
||||||
|
|
||||||
static int mul(int64_t *dims, int ndims) {
|
|
||||||
int result = 1;
|
|
||||||
for (int i = 0; i < ndims; i++) {
|
|
||||||
result *= dims[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void repeat(char c, int n) {
|
|
||||||
for (int i = 0; i < n; i++) {
|
|
||||||
fprintf(stderr, "%c", c);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print_tensor(const void *tensor, void (*cb)(const void *, int),
|
|
||||||
int shape,
|
|
||||||
int64_t *dims, int ndims, int stride,
|
|
||||||
int nitems, int pad) {
|
|
||||||
fprintf(stderr, "[");
|
|
||||||
for (int i = 0; i < dims[0]; i++) {
|
|
||||||
if (i >= nitems && i < dims[0] - nitems) {
|
|
||||||
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems);
|
|
||||||
int skip = dims[0] - 2 * nitems;
|
|
||||||
if (ndims > 1) {
|
|
||||||
stride += mul(dims + 1, ndims - 1) * skip;
|
|
||||||
repeat('\n', ndims - 1);
|
|
||||||
repeat(' ', shape - ndims + 1 + pad);
|
|
||||||
}
|
|
||||||
i += skip - 1;
|
|
||||||
} else if (ndims > 1) {
|
|
||||||
print_tensor(tensor, cb, shape, dims + 1, ndims - 1, stride,
|
|
||||||
nitems, pad);
|
|
||||||
stride += mul(dims + 1, ndims - 1);
|
|
||||||
if (i < dims[0] - 1) {
|
|
||||||
fprintf(stderr, ", ");
|
|
||||||
repeat('\n', ndims - 1);
|
|
||||||
repeat(' ', shape - ndims + 1 + pad);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cb(tensor, stride + i);
|
|
||||||
if (i < dims[0] - 1) {
|
|
||||||
fprintf(stderr, ", ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fprintf(stderr, "]");
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print_tensor_f16(const void *tensor, int i) {
|
|
||||||
float value = ggml_fp16_to_fp32(((const ggml_fp16_t *)tensor)[i]);
|
|
||||||
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print_tensor_f32(const void *tensor, int i) {
|
|
||||||
float value = ((const float *)tensor)[i];
|
|
||||||
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void print_tensor_i32(const void *tensor, int i) {
|
|
||||||
int32_t value = ((const int32_t *)tensor)[i];
|
|
||||||
fprintf(stderr, "%s%d", value < 0 ? "" : " ", value);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
|
|
||||||
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name,
|
|
||||||
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
|
|
||||||
tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
|
||||||
|
|
||||||
if (!verbose) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < indent; i++) {
|
|
||||||
fprintf(stderr, " ");
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (tensor->type) {
|
|
||||||
case GGML_TYPE_F16:
|
|
||||||
print_tensor(ggml_get_data(tensor), print_tensor_f16, ggml_n_dims(tensor),
|
|
||||||
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_F32:
|
|
||||||
print_tensor(ggml_get_data(tensor), print_tensor_f32, ggml_n_dims(tensor),
|
|
||||||
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
|
|
||||||
break;
|
|
||||||
case GGML_TYPE_I32:
|
|
||||||
print_tensor(ggml_get_data(tensor), print_tensor_i32, ggml_n_dims(tensor),
|
|
||||||
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
fprintf(stderr, "<unsupported type>\n");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
void ollama_debug(const struct ggml_tensor *tensor, bool verbose) {
|
|
||||||
ollama_debug_tensor(tensor, verbose, ">>> ", 4);
|
|
||||||
|
|
||||||
for (int i = 0; i < GGML_MAX_SRC && tensor->src[i] != NULL; ++i) {
|
|
||||||
char src[8];
|
|
||||||
const int n = snprintf(src, sizeof(src), " src%d ", i);
|
|
||||||
if (n >= sizeof(src)) {
|
|
||||||
src[sizeof(src) - 1] = '\0';
|
|
||||||
}
|
|
||||||
|
|
||||||
ollama_debug_tensor(tensor->src[i], verbose, src, 4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build !debug
|
|
||||||
|
|
||||||
package ggml
|
|
||||||
|
|
||||||
func Threads(n int) int {
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build debug
|
|
||||||
|
|
||||||
package ggml
|
|
||||||
|
|
||||||
func Threads(_ int) int {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
@@ -3,7 +3,6 @@ package nn
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -12,50 +11,40 @@ import (
|
|||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - ctx: Context for tensor operations
|
// - ctx: Context for tensor operations
|
||||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
|
||||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
|
||||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
|
||||||
|
// - mask: Optional attention mask that is added to the attention score. If
|
||||||
|
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
|
||||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
//
|
//
|
||||||
// Attention output with shape [d_v, heads, seq_len_q]
|
// Attention output with shape [d_v, heads, seq_len_q]
|
||||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||||
if key != nil && value != nil {
|
if query.Dim(0) != key.Dim(0) {
|
||||||
if query.Dim(0) != key.Dim(0) {
|
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.Dim(1) != value.Dim(1) {
|
|
||||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if key.Dim(2) != value.Dim(2) {
|
|
||||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
|
||||||
}
|
|
||||||
|
|
||||||
if cache != nil {
|
|
||||||
cache.Put(ctx, key, value)
|
|
||||||
}
|
|
||||||
} else if cache == nil {
|
|
||||||
panic("key & value tensors must be provided if cache is nil")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var mask ml.Tensor
|
if mask != nil && query.Dim(1) != mask.Dim(1) {
|
||||||
if cache != nil {
|
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
|
||||||
key, value, mask = cache.Get(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
if key.Dim(1) != value.Dim(0) {
|
||||||
// will do any expected backend-specific transformations for us
|
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
|
||||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
}
|
||||||
|
|
||||||
|
if mask != nil && key.Dim(1) != mask.Dim(0) {
|
||||||
|
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if key.Dim(2) != value.Dim(2) {
|
||||||
|
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
||||||
} else {
|
} else {
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
kq := key.MulmatFullPrec(ctx, query)
|
kq := key.MulmatFullPrec(ctx, query)
|
||||||
|
|
||||||
kq = kq.Scale(ctx, scale)
|
kq = kq.Scale(ctx, scale)
|
||||||
|
|||||||
@@ -1,37 +0,0 @@
|
|||||||
package input
|
|
||||||
|
|
||||||
// Input represents one token in the input stream
|
|
||||||
type Input struct {
|
|
||||||
// Token is a single element of text.
|
|
||||||
Token int32
|
|
||||||
|
|
||||||
// Multimodal is opaque data representing a non-text
|
|
||||||
// element such as an image (or part of one if the image
|
|
||||||
// can be processed in pieces). It may be either together
|
|
||||||
// with Token or on its own.
|
|
||||||
Multimodal any
|
|
||||||
|
|
||||||
// MultimodalHash is a unique representation of the data
|
|
||||||
// stored in Multimodal, used for caching and comparing
|
|
||||||
// equality.
|
|
||||||
MultimodalHash uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
// MultimodalIndex is a multimodal element (such as an image)
|
|
||||||
// together with an index into the slice of Inputs with the
|
|
||||||
// corresponding token. Note that the index is not the same
|
|
||||||
// as the position - to find that use the index with the
|
|
||||||
// Positions slice.
|
|
||||||
type MultimodalIndex struct {
|
|
||||||
Index int
|
|
||||||
Multimodal any
|
|
||||||
}
|
|
||||||
|
|
||||||
// Options contains the inputs for a model forward pass
|
|
||||||
type Options struct {
|
|
||||||
Inputs []int32
|
|
||||||
Multimodal []MultimodalIndex
|
|
||||||
Positions []int32
|
|
||||||
Sequences []int
|
|
||||||
Outputs []int32
|
|
||||||
}
|
|
||||||
@@ -3,6 +3,7 @@ package model
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"image"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -15,52 +16,23 @@ import (
|
|||||||
_ "golang.org/x/image/tiff"
|
_ "golang.org/x/image/tiff"
|
||||||
_ "golang.org/x/image/webp"
|
_ "golang.org/x/image/webp"
|
||||||
|
|
||||||
fs "github.com/ollama/ollama/fs/ggml"
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
_ "github.com/ollama/ollama/ml/backend"
|
_ "github.com/ollama/ollama/ml/backend"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
// Options contains the inputs for a model forward pass
|
||||||
|
type Options struct {
|
||||||
|
Inputs []int32
|
||||||
|
Positions []int32
|
||||||
|
Sequences []int
|
||||||
|
Outputs []int32
|
||||||
|
|
||||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
Images []image.Image
|
||||||
type Model interface {
|
|
||||||
Forward(ml.Context, input.Options) (ml.Tensor, error)
|
|
||||||
|
|
||||||
Backend() ml.Backend
|
|
||||||
Config() config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultimodalProcessor must be implemented by multimodal models.
|
type config struct {
|
||||||
type MultimodalProcessor interface {
|
Cache kvcache.Cache
|
||||||
// EncodeMultimodal processes a single input (such as an image) and
|
|
||||||
// generates an output (typically an embedding) that can be used by the model.
|
|
||||||
//
|
|
||||||
// The return value is most typically an ml.Tensor, however, different
|
|
||||||
// type are possible, such as an object containing a tensor plus
|
|
||||||
// additional metadata, a slice of tensors or even just the original input.
|
|
||||||
//
|
|
||||||
// The result may be cached by the runner.
|
|
||||||
EncodeMultimodal(ml.Context, []byte) (any, error)
|
|
||||||
|
|
||||||
// PostTokenize is called after tokenization to allow the model to edit the
|
|
||||||
// input stream to correctly arrange multimodal elements.
|
|
||||||
//
|
|
||||||
// The input is a slice of tokens with the results of EncodeMultimodal interleaved
|
|
||||||
// in the order that the user provided them. Each element of the slice will be
|
|
||||||
// either a single token or single multimodal object.
|
|
||||||
//
|
|
||||||
// The model must ensure that inputs are stored according to how they will be
|
|
||||||
// processed and stored in the cache. For example, Llava-style models should insert
|
|
||||||
// placeholder tokens equal to the feature size of the corresponding image with
|
|
||||||
// the image itself attached to and split across these tokens. When Forward is called
|
|
||||||
// a partial subset of these tokens may be submitted according to the batch size.
|
|
||||||
//
|
|
||||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
|
||||||
// that is modified to ensure that there is a unique hash value that accurately
|
|
||||||
// represents the contents.
|
|
||||||
PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Base implements the common fields and methods for all models
|
// Base implements the common fields and methods for all models
|
||||||
@@ -69,10 +41,6 @@ type Base struct {
|
|||||||
config
|
config
|
||||||
}
|
}
|
||||||
|
|
||||||
type config struct {
|
|
||||||
Cache kvcache.Cache
|
|
||||||
}
|
|
||||||
|
|
||||||
// Backend returns the underlying backend that will run the model
|
// Backend returns the underlying backend that will run the model
|
||||||
func (m *Base) Backend() ml.Backend {
|
func (m *Base) Backend() ml.Backend {
|
||||||
return m.b
|
return m.b
|
||||||
@@ -82,6 +50,14 @@ func (m *Base) Config() config {
|
|||||||
return m.config
|
return m.config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||||
|
type Model interface {
|
||||||
|
Forward(ml.Context, Options) (ml.Tensor, error)
|
||||||
|
|
||||||
|
Backend() ml.Backend
|
||||||
|
Config() config
|
||||||
|
}
|
||||||
|
|
||||||
var models = make(map[string]func(ml.Config) (Model, error))
|
var models = make(map[string]func(ml.Config) (Model, error))
|
||||||
|
|
||||||
// Register registers a model constructor for the given architecture
|
// Register registers a model constructor for the given architecture
|
||||||
@@ -124,36 +100,6 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
|
||||||
r, err := os.Open(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer r.Close()
|
|
||||||
meta, _, err := fs.Decode(r, -1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return getTextProcessor(meta.KV())
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
|
|
||||||
arch := kv.Architecture()
|
|
||||||
f, ok := models[arch]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
|
||||||
}
|
|
||||||
m, err := f(kv)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tp, ok := m.(TextProcessor)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
|
||||||
}
|
|
||||||
return tp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||||
t := v.Type()
|
t := v.Type()
|
||||||
|
|
||||||
@@ -280,7 +226,7 @@ func canNil(t reflect.Type) bool {
|
|||||||
t.Kind() == reflect.Slice
|
t.Kind() == reflect.Slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
|
||||||
if len(opts.Positions) != len(opts.Sequences) {
|
if len(opts.Positions) != len(opts.Sequences) {
|
||||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
||||||
}
|
}
|
||||||
@@ -291,7 +237,7 @@ func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
|||||||
|
|
||||||
cache := m.Config().Cache
|
cache := m.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
err := cache.StartForward(ctx, opts)
|
err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,15 +3,12 @@ package model
|
|||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
fs "github.com/ollama/ollama/fs/ggml"
|
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/backend/ggml"
|
"github.com/ollama/ollama/ml/backend/ggml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseTags(t *testing.T) {
|
func TestParseTags(t *testing.T) {
|
||||||
@@ -137,40 +134,3 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
|||||||
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetTextProcessor(t *testing.T) {
|
|
||||||
tp, err := getTextProcessor(fs.KV{})
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
|
||||||
|
|
||||||
models["dummy"] = func(ml.Config) (Model, error) {
|
|
||||||
return notTextProcessorModel{}, nil
|
|
||||||
}
|
|
||||||
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
|
|
||||||
if err == nil {
|
|
||||||
t.Error("expected error")
|
|
||||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
} else if tp != nil {
|
|
||||||
t.Error("expected nil tp")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type notTextProcessorModel struct{}
|
|
||||||
|
|
||||||
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (notTextProcessorModel) Backend() ml.Backend {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (notTextProcessorModel) Config() config {
|
|
||||||
panic("unimplemented")
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
package gemma2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
"github.com/ollama/ollama/ml/nn"
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
hiddenSize, numHeads, numKVHeads int
|
|
||||||
attnKeyLen, attnValLen int
|
|
||||||
eps, ropeBase, ropeScale float32
|
|
||||||
attnLogitSoftcap float32
|
|
||||||
finalLogitSoftcap float32
|
|
||||||
largeModelScaling bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Model struct {
|
|
||||||
model.Base
|
|
||||||
model.SentencePieceModel
|
|
||||||
|
|
||||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
|
||||||
Layers []Layer `gguf:"blk"`
|
|
||||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
|
||||||
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
|
|
||||||
|
|
||||||
*Options
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
gemma27BLayerCount = 46
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(c ml.Config) (model.Model, error) {
|
|
||||||
m := Model{
|
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
|
||||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Layers: make([]Layer, c.Uint("block_count")),
|
|
||||||
Options: &Options{
|
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
|
||||||
attnKeyLen: int(c.Uint("attention.key_length")),
|
|
||||||
attnValLen: int(c.Uint("attention.value_length")),
|
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
|
||||||
ropeBase: c.Float("rope.freq_base", 10000.0),
|
|
||||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
|
||||||
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
|
|
||||||
finalLogitSoftcap: c.Float("final_logit_softcapping"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
|
||||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
|
||||||
m.Cache.SetConfig(ml.CacheConfig{})
|
|
||||||
|
|
||||||
return &m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type SelfAttention struct {
|
|
||||||
Query *nn.Linear `gguf:"attn_q"`
|
|
||||||
Key *nn.Linear `gguf:"attn_k"`
|
|
||||||
Value *nn.Linear `gguf:"attn_v"`
|
|
||||||
Output *nn.Linear `gguf:"attn_output"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
|
||||||
batchSize := hiddenState.Dim(1)
|
|
||||||
ropeType := uint32(2)
|
|
||||||
|
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
|
||||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
|
||||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
|
||||||
|
|
||||||
if opts.largeModelScaling {
|
|
||||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
|
||||||
} else {
|
|
||||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
|
|
||||||
}
|
|
||||||
|
|
||||||
k := sa.Key.Forward(ctx, hiddenState)
|
|
||||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
|
||||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
|
|
||||||
|
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
|
||||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
|
||||||
|
|
||||||
cache.Put(ctx, k, v)
|
|
||||||
k, v, mask := cache.Get(ctx)
|
|
||||||
|
|
||||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
kq := k.Mulmat(ctx, q)
|
|
||||||
|
|
||||||
// logit softcap
|
|
||||||
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
|
|
||||||
kq = kq.Tanh(ctx)
|
|
||||||
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
|
|
||||||
|
|
||||||
kq = kq.Add(ctx, mask)
|
|
||||||
kq = kq.Softmax(ctx)
|
|
||||||
|
|
||||||
kqv := v.Mulmat(ctx, kq)
|
|
||||||
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, kqv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
||||||
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type MLP struct {
|
|
||||||
Up *nn.Linear `gguf:"ffn_up"`
|
|
||||||
Down *nn.Linear `gguf:"ffn_down"`
|
|
||||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
|
||||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
|
||||||
return mlp.Down.Forward(ctx, hiddenState)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Layer struct {
|
|
||||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
||||||
SelfAttention *SelfAttention
|
|
||||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
|
||||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
||||||
MLP *MLP
|
|
||||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
|
||||||
residual := hiddenState
|
|
||||||
|
|
||||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
|
||||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
|
|
||||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
|
||||||
// we need logits for.
|
|
||||||
if outputs != nil {
|
|
||||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
|
||||||
residual = residual.Rows(ctx, outputs)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = hiddenState.Add(ctx, residual)
|
|
||||||
residual = hiddenState
|
|
||||||
|
|
||||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
|
||||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
return hiddenState.Add(ctx, residual)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
|
||||||
|
|
||||||
if len(m.Layers) == gemma27BLayerCount {
|
|
||||||
m.Options.largeModelScaling = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
cacheType := i % 2
|
|
||||||
m.Cache.SetLayer(i)
|
|
||||||
wc := m.Cache.(*kvcache.WrapperCache)
|
|
||||||
wc.SetLayerType(cacheType)
|
|
||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
|
||||||
if i == len(m.Layers)-1 {
|
|
||||||
lastLayerOutputs = outputs
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
|
||||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
|
||||||
|
|
||||||
// final logit softcap
|
|
||||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
|
|
||||||
hiddenState = hiddenState.Tanh(ctx)
|
|
||||||
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
|
|
||||||
return hiddenState.Rows(ctx, outputs), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
model.Register("gemma2", New)
|
|
||||||
}
|
|
||||||
@@ -1,177 +0,0 @@
|
|||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"hash/fnv"
|
|
||||||
"image"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
"github.com/ollama/ollama/ml/nn"
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Model struct {
|
|
||||||
model.Base
|
|
||||||
model.SentencePieceModel
|
|
||||||
|
|
||||||
*VisionModel `gguf:"v,vision"`
|
|
||||||
*TextModel
|
|
||||||
|
|
||||||
*MultiModalProjector `gguf:"mm"`
|
|
||||||
|
|
||||||
ImageProcessor
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
|
||||||
|
|
||||||
type MultiModalProjector struct {
|
|
||||||
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
|
|
||||||
InputProjection *nn.Linear `gguf:"mm_input_projection"`
|
|
||||||
|
|
||||||
tokensPerImage int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
|
|
||||||
l := visionOutputs.Dim(0)
|
|
||||||
|
|
||||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
||||||
patchesPerImage := imageSize / patchSize
|
|
||||||
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
|
|
||||||
|
|
||||||
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
|
|
||||||
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
|
|
||||||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
|
|
||||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
||||||
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
|
|
||||||
|
|
||||||
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
|
|
||||||
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
|
|
||||||
return visionOutputs
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(c ml.Config) (model.Model, error) {
|
|
||||||
m := Model{
|
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
|
||||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
|
||||||
EOS: int32(1),
|
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
|
||||||
EOT: int32(106),
|
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
ImageProcessor: newImageProcessor(c),
|
|
||||||
VisionModel: newVisionModel(c),
|
|
||||||
TextModel: newTextModel(c),
|
|
||||||
MultiModalProjector: &MultiModalProjector{
|
|
||||||
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
|
||||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
|
||||||
|
|
||||||
return &m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
|
||||||
if len(m.VisionModel.Layers) == 0 {
|
|
||||||
return nil, model.ErrNoVisionModel
|
|
||||||
}
|
|
||||||
|
|
||||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
f32s, err := m.ImageProcessor.ProcessImage(image)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
|
||||||
m.ImageProcessor.imageSize,
|
|
||||||
m.ImageProcessor.imageSize,
|
|
||||||
m.ImageProcessor.numChannels,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
|
||||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
|
|
||||||
return visionOutputs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type imageToken struct {
|
|
||||||
embedding ml.Tensor
|
|
||||||
index int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
|
||||||
var result []input.Input
|
|
||||||
fnvHash := fnv.New64a()
|
|
||||||
|
|
||||||
for _, inp := range inputs {
|
|
||||||
if inp.Multimodal == nil {
|
|
||||||
result = append(result, inp)
|
|
||||||
} else {
|
|
||||||
imageInputs := []input.Input{
|
|
||||||
{Token: 108}, // "\n\n"
|
|
||||||
{Token: 255999}, // "<start_of_image>""
|
|
||||||
}
|
|
||||||
result = append(result, imageInputs...)
|
|
||||||
|
|
||||||
// add image embeddings
|
|
||||||
inputMultimodal := inp.Multimodal.(ml.Tensor)
|
|
||||||
|
|
||||||
for i := range inputMultimodal.Dim(1) {
|
|
||||||
fnvHash.Reset()
|
|
||||||
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
|
|
||||||
fnvHash.Write([]byte{byte(i)})
|
|
||||||
|
|
||||||
imageToken := imageToken{embedding: inputMultimodal, index: i}
|
|
||||||
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
|
|
||||||
}
|
|
||||||
|
|
||||||
result = append(result,
|
|
||||||
input.Input{Token: 256000}, // <end_of_image>
|
|
||||||
input.Input{Token: 108}, // "\n\n"
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
model.Register("gemma3", New)
|
|
||||||
}
|
|
||||||
@@ -1,247 +0,0 @@
|
|||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
"github.com/ollama/ollama/ml/nn"
|
|
||||||
"github.com/ollama/ollama/model"
|
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TextOptions struct {
|
|
||||||
hiddenSize, numHeads, numKVHeads int
|
|
||||||
attnKeyLen, attnValLen int
|
|
||||||
eps, ropeScale float32
|
|
||||||
ropeLocalBase, ropeGlobalBase float32
|
|
||||||
largeModelScaling bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextModel struct {
|
|
||||||
model.Base
|
|
||||||
model.SentencePieceModel
|
|
||||||
|
|
||||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
|
||||||
Layers []TextLayer `gguf:"blk"`
|
|
||||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
|
||||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
|
||||||
|
|
||||||
*TextOptions
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
gemmaGlobalCacheCount = 6
|
|
||||||
gemma27BLayerCount = 62
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
cacheTypeSWA = iota
|
|
||||||
cacheTypeCausal
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTextModel(c ml.Config) *TextModel {
|
|
||||||
numBlocks := int(c.Uint("block_count"))
|
|
||||||
|
|
||||||
m := TextModel{
|
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
|
||||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Layers: make([]TextLayer, numBlocks),
|
|
||||||
TextOptions: &TextOptions{
|
|
||||||
hiddenSize: int(c.Uint("embedding_length")),
|
|
||||||
numHeads: int(c.Uint("attention.head_count")),
|
|
||||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
|
||||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
|
||||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
|
||||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
|
||||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
|
||||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
|
||||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if numBlocks == gemma27BLayerCount {
|
|
||||||
m.largeModelScaling = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return &m
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextSelfAttention struct {
|
|
||||||
Query *nn.Linear `gguf:"attn_q"`
|
|
||||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
|
||||||
Key *nn.Linear `gguf:"attn_k"`
|
|
||||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
|
||||||
Value *nn.Linear `gguf:"attn_v"`
|
|
||||||
Output *nn.Linear `gguf:"attn_output"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
|
||||||
batchSize := hiddenState.Dim(1)
|
|
||||||
ropeType := uint32(2)
|
|
||||||
|
|
||||||
ropeBase := opts.ropeLocalBase
|
|
||||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
|
||||||
ropeBase = opts.ropeGlobalBase
|
|
||||||
}
|
|
||||||
|
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
|
||||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
|
||||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
|
||||||
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
|
||||||
|
|
||||||
if opts.largeModelScaling {
|
|
||||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
|
||||||
} else {
|
|
||||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
|
|
||||||
}
|
|
||||||
|
|
||||||
k := sa.Key.Forward(ctx, hiddenState)
|
|
||||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
|
||||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
|
||||||
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
|
|
||||||
|
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
|
||||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
|
||||||
|
|
||||||
scaleFactor := 1.0
|
|
||||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
|
||||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, kqv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
||||||
ropeBase := m.TextOptions.ropeLocalBase
|
|
||||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
|
||||||
ropeBase = m.TextOptions.ropeGlobalBase
|
|
||||||
}
|
|
||||||
|
|
||||||
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextMLP struct {
|
|
||||||
Up *nn.Linear `gguf:"ffn_up"`
|
|
||||||
Down *nn.Linear `gguf:"ffn_down"`
|
|
||||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
||||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
|
||||||
return mlp.Down.Forward(ctx, hiddenState)
|
|
||||||
}
|
|
||||||
|
|
||||||
type TextLayer struct {
|
|
||||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
||||||
SelfAttention *TextSelfAttention
|
|
||||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
|
||||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
||||||
MLP *TextMLP
|
|
||||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
|
||||||
residual := hiddenState
|
|
||||||
|
|
||||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
|
|
||||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
|
|
||||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
|
||||||
// we need logits for.
|
|
||||||
if outputs != nil {
|
|
||||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
|
||||||
residual = residual.Rows(ctx, outputs)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = hiddenState.Add(ctx, residual)
|
|
||||||
residual = hiddenState
|
|
||||||
|
|
||||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
|
||||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
return hiddenState.Add(ctx, residual)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
|
|
||||||
var embedding ml.Tensor
|
|
||||||
var src, dst, length int
|
|
||||||
var except []int
|
|
||||||
|
|
||||||
for _, image := range multimodal {
|
|
||||||
imageToken := image.Multimodal.(imageToken)
|
|
||||||
imageSrc := imageToken.index
|
|
||||||
imageDst := image.Index
|
|
||||||
|
|
||||||
if embedding == nil {
|
|
||||||
embedding = imageToken.embedding
|
|
||||||
src = imageSrc
|
|
||||||
dst = imageDst
|
|
||||||
length = 1
|
|
||||||
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
|
|
||||||
src = imageSrc
|
|
||||||
dst = imageDst
|
|
||||||
length++
|
|
||||||
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
|
|
||||||
length++
|
|
||||||
} else {
|
|
||||||
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
|
||||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
|
||||||
|
|
||||||
embedding = imageToken.embedding
|
|
||||||
src = imageSrc
|
|
||||||
dst = imageDst
|
|
||||||
length = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
except = append(except, imageDst)
|
|
||||||
}
|
|
||||||
|
|
||||||
if embedding != nil {
|
|
||||||
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
|
|
||||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
|
|
||||||
}
|
|
||||||
|
|
||||||
return except
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
|
||||||
|
|
||||||
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
|
|
||||||
|
|
||||||
for i, layer := range m.Layers {
|
|
||||||
// gemma alternates between the sliding window (local) and causal (global)
|
|
||||||
// kv cache every 6 layers
|
|
||||||
cacheType := cacheTypeSWA
|
|
||||||
if (i+1)%gemmaGlobalCacheCount == 0 {
|
|
||||||
cacheType = cacheTypeCausal
|
|
||||||
}
|
|
||||||
cache.SetLayer(i)
|
|
||||||
wc := cache.(*kvcache.WrapperCache)
|
|
||||||
wc.SetLayerType(cacheType)
|
|
||||||
|
|
||||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
|
||||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
|
||||||
}
|
|
||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
|
||||||
if i == len(m.Layers)-1 {
|
|
||||||
lastLayerOutputs = outputs
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
|
||||||
return m.Output.Forward(ctx, hiddenState)
|
|
||||||
}
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
"github.com/ollama/ollama/ml/nn"
|
|
||||||
)
|
|
||||||
|
|
||||||
var batchSize int = 1
|
|
||||||
|
|
||||||
type VisionSelfAttention struct {
|
|
||||||
Query *nn.Linear `gguf:"attn_q"`
|
|
||||||
Key *nn.Linear `gguf:"attn_k"`
|
|
||||||
Value *nn.Linear `gguf:"attn_v"`
|
|
||||||
Output *nn.Linear `gguf:"attn_output"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
|
||||||
headDim := opts.hiddenSize / opts.numHeads
|
|
||||||
|
|
||||||
query := sa.Query.Forward(ctx, hiddenState)
|
|
||||||
key := sa.Key.Forward(ctx, hiddenState)
|
|
||||||
value := sa.Value.Forward(ctx, hiddenState)
|
|
||||||
|
|
||||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
|
||||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
|
||||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
|
||||||
|
|
||||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
|
||||||
|
|
||||||
hiddenState = sa.Output.Forward(ctx, attention)
|
|
||||||
return hiddenState
|
|
||||||
}
|
|
||||||
|
|
||||||
type VisionMLP struct {
|
|
||||||
FC1 *nn.Linear `gguf:"fc1"`
|
|
||||||
FC2 *nn.Linear `gguf:"fc2"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
|
||||||
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
|
|
||||||
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
|
|
||||||
return hiddenState
|
|
||||||
}
|
|
||||||
|
|
||||||
type VisionEncoderLayer struct {
|
|
||||||
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
|
|
||||||
SelfAttention *VisionSelfAttention
|
|
||||||
|
|
||||||
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
|
|
||||||
MLP *VisionMLP `gguf:"mlp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
|
||||||
residual := hiddenState
|
|
||||||
|
|
||||||
// self attention
|
|
||||||
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
|
|
||||||
hiddenState = hiddenState.Add(ctx, residual)
|
|
||||||
residual = hiddenState
|
|
||||||
|
|
||||||
// feed forward
|
|
||||||
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
|
|
||||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
|
||||||
return hiddenState.Add(ctx, residual)
|
|
||||||
}
|
|
||||||
|
|
||||||
type VisionModelOptions struct {
|
|
||||||
hiddenSize, numHeads int
|
|
||||||
imageSize, patchSize int
|
|
||||||
eps float32
|
|
||||||
}
|
|
||||||
|
|
||||||
type VisionModel struct {
|
|
||||||
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
|
|
||||||
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
|
|
||||||
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
|
|
||||||
|
|
||||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
|
||||||
|
|
||||||
*VisionModelOptions
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
|
||||||
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
|
|
||||||
|
|
||||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
|
||||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
|
||||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
|
||||||
|
|
||||||
positions := make([]int32, numPatches)
|
|
||||||
for i := range positions {
|
|
||||||
positions[i] = int32(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
|
|
||||||
|
|
||||||
for _, layer := range m.Layers {
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
|
|
||||||
return hiddenState
|
|
||||||
}
|
|
||||||
|
|
||||||
func newVisionModel(c ml.Config) *VisionModel {
|
|
||||||
return &VisionModel{
|
|
||||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
|
||||||
VisionModelOptions: &VisionModelOptions{
|
|
||||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
|
||||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
|
||||||
|
|
||||||
imageSize: int(c.Uint("vision.image_size")),
|
|
||||||
patchSize: int(c.Uint("vision.patch_size")),
|
|
||||||
|
|
||||||
eps: c.Float("vision.attention.layer_norm_epsilon"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
package gemma3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"image"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/ml"
|
|
||||||
"github.com/ollama/ollama/model/imageproc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ImageProcessor struct {
|
|
||||||
imageSize, patchSize, numChannels int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newImageProcessor(c ml.Config) ImageProcessor {
|
|
||||||
return ImageProcessor{
|
|
||||||
imageSize: int(c.Uint("vision.image_size")),
|
|
||||||
patchSize: int(c.Uint("vision.patch_size")),
|
|
||||||
numChannels: int(c.Uint("vision.num_channels")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
|
|
||||||
var pixelVals, rVals, gVals, bVals []float32
|
|
||||||
|
|
||||||
bounds := img.Bounds()
|
|
||||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
|
||||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
|
||||||
c := img.At(x, y)
|
|
||||||
r, g, b, _ := c.RGBA()
|
|
||||||
rVal := float32(r>>8) / 255.0
|
|
||||||
gVal := float32(g>>8) / 255.0
|
|
||||||
bVal := float32(b>>8) / 255.0
|
|
||||||
|
|
||||||
rVal = (rVal - mean[0]) / std[0]
|
|
||||||
gVal = (gVal - mean[1]) / std[1]
|
|
||||||
bVal = (bVal - mean[2]) / std[2]
|
|
||||||
|
|
||||||
rVals = append(rVals, rVal)
|
|
||||||
gVals = append(gVals, gVal)
|
|
||||||
bVals = append(bVals, bVal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pixelVals = append(pixelVals, rVals...)
|
|
||||||
pixelVals = append(pixelVals, gVals...)
|
|
||||||
pixelVals = append(pixelVals, bVals...)
|
|
||||||
|
|
||||||
return pixelVals
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
|
|
||||||
outputSize := image.Point{p.imageSize, p.imageSize}
|
|
||||||
newImage := imageproc.Composite(img)
|
|
||||||
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
|
|
||||||
|
|
||||||
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
@@ -1,18 +1,16 @@
|
|||||||
package llama
|
package llama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
ropeDim uint32
|
ropeDim uint32
|
||||||
@@ -31,10 +29,6 @@ type Model struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New(c ml.Config) (model.Model, error) {
|
func New(c ml.Config) (model.Model, error) {
|
||||||
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
|
||||||
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
|
||||||
}
|
|
||||||
|
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||||
@@ -66,38 +60,43 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SelfAttention struct {
|
type SelfAttention struct {
|
||||||
Query *nn.Linear `gguf:"attn_q"`
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
Key *nn.Linear `gguf:"attn_k"`
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
Value *nn.Linear `gguf:"attn_v"`
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
Output *nn.Linear `gguf:"attn_output"`
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
headDim := opts.hiddenSize / opts.numHeads
|
headDim := opts.hiddenSize / opts.numHeads
|
||||||
ropeType := uint32(0)
|
|
||||||
|
|
||||||
q := sa.Query.Forward(ctx, hiddenState)
|
q := sa.Query.Forward(ctx, hiddenState)
|
||||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
k := sa.Key.Forward(ctx, hiddenState)
|
k := sa.Key.Forward(ctx, hiddenState)
|
||||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
v := sa.Value.Forward(ctx, hiddenState)
|
v := sa.Value.Forward(ctx, hiddenState)
|
||||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
cache.Put(ctx, k, v)
|
||||||
|
k, v, mask := cache.Get(ctx)
|
||||||
|
|
||||||
|
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
|
||||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, kqv)
|
return sa.Output.Forward(ctx, kqv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MLP struct {
|
type MLP struct {
|
||||||
@@ -139,18 +138,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,18 +1,10 @@
|
|||||||
package mllama
|
package mllama
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"hash/fnv"
|
|
||||||
"image"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/ml/nn"
|
"github.com/ollama/ollama/ml/nn"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
@@ -33,10 +25,6 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func New(c ml.Config) (model.Model, error) {
|
func New(c ml.Config) (model.Model, error) {
|
||||||
// Verify unified config
|
|
||||||
if c.Uint("vision.block_count") == 0 {
|
|
||||||
return nil, fmt.Errorf("non-unified vision model not supported")
|
|
||||||
}
|
|
||||||
m := Model{
|
m := Model{
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||||
@@ -55,103 +43,59 @@ func New(c ml.Config) (model.Model, error) {
|
|||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
encoderCache := kvcache.NewEncoderCache()
|
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
|
||||||
encoderCache.SetConfig(ml.CacheConfig{})
|
|
||||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
|
||||||
|
|
||||||
return &m, nil
|
return &m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
|
||||||
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
|
|
||||||
return nil, model.ErrNoVisionModel
|
|
||||||
}
|
|
||||||
|
|
||||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
|
||||||
m.ImageProcessor.imageSize,
|
|
||||||
m.ImageProcessor.imageSize,
|
|
||||||
m.ImageProcessor.numChannels,
|
|
||||||
m.ImageProcessor.maxNumTiles,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
positions := make([]int32, 1601)
|
|
||||||
for i := range positions {
|
|
||||||
positions[i] = int32(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
|
|
||||||
return m.Projector.Forward(ctx, crossAttentionStates), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
|
||||||
var images []input.Input
|
|
||||||
fnvHash := fnv.New64a()
|
|
||||||
|
|
||||||
for i := range inputs {
|
|
||||||
if inputs[i].Multimodal == nil {
|
|
||||||
if len(images) > 0 {
|
|
||||||
inputs[i].Multimodal = images[0].Multimodal
|
|
||||||
inputs[i].MultimodalHash = images[0].MultimodalHash
|
|
||||||
for j := 1; j < len(images); j++ {
|
|
||||||
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
|
|
||||||
fnvHash.Reset()
|
|
||||||
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
|
|
||||||
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
|
|
||||||
inputs[i].MultimodalHash = fnvHash.Sum64()
|
|
||||||
}
|
|
||||||
images = nil
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
images = append(images, inputs[i])
|
|
||||||
inputs[i].Token = -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
|
|
||||||
|
|
||||||
return inputs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
|
||||||
var crossAttentionStates ml.Tensor
|
var crossAttentionStates ml.Tensor
|
||||||
if len(opts.Multimodal) > 0 {
|
if opts.Images != nil {
|
||||||
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
|
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pixelValues, err := ctx.FromFloatSlice(f32s,
|
||||||
|
m.ImageProcessor.imageSize,
|
||||||
|
m.ImageProcessor.imageSize,
|
||||||
|
m.ImageProcessor.numChannels,
|
||||||
|
m.ImageProcessor.maxNumTiles,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
positions := make([]int32, 1601)
|
||||||
|
for i := range positions {
|
||||||
|
positions[i] = int32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
|
||||||
|
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
|
||||||
}
|
}
|
||||||
|
|
||||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,31 +10,36 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TextSelfAttention struct {
|
type TextSelfAttention struct {
|
||||||
Query *nn.Linear `gguf:"attn_q"`
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
Key *nn.Linear `gguf:"attn_k"`
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
Value *nn.Linear `gguf:"attn_v"`
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
Output *nn.Linear `gguf:"attn_output"`
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||||
batchSize := hiddenState.Dim(1)
|
batchSize := hiddenState.Dim(1)
|
||||||
headDim := opts.hiddenSize / opts.numHeads
|
headDim := opts.hiddenSize / opts.numHeads
|
||||||
ropeType := uint32(0)
|
|
||||||
|
|
||||||
query := sa.Query.Forward(ctx, hiddenState)
|
query := sa.Query.Forward(ctx, hiddenState)
|
||||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
key := sa.Key.Forward(ctx, hiddenState)
|
key := sa.Key.Forward(ctx, hiddenState)
|
||||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||||
|
|
||||||
value := sa.Value.Forward(ctx, hiddenState)
|
value := sa.Value.Forward(ctx, hiddenState)
|
||||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
cache.Put(ctx, key, value)
|
||||||
|
key, value, mask := cache.Get(ctx)
|
||||||
|
|
||||||
|
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
|
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
|
||||||
return sa.Output.Forward(ctx, attention)
|
return sa.Output.Forward(ctx, attention)
|
||||||
@@ -42,11 +47,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
|||||||
|
|
||||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
// This will only get called for layers in the cache, which are just the self attention layers
|
// This will only get called for layers in the cache, which are just the self attention layers
|
||||||
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||||
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextMLP struct {
|
type TextMLP struct {
|
||||||
@@ -106,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|||||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
||||||
|
|
||||||
var key, value ml.Tensor
|
var key, value, mask ml.Tensor
|
||||||
if crossAttentionStates != nil {
|
if crossAttentionStates != nil {
|
||||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||||
|
|
||||||
@@ -118,23 +119,16 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|||||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||||
|
|
||||||
cache.Put(ctx, key, value)
|
cache.Put(ctx, key, value)
|
||||||
|
} else {
|
||||||
|
key, value, mask = cache.Get(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
key, value, _ = cache.Get(ctx)
|
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
|
||||||
|
|
||||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
|
||||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||||
|
|
||||||
kq := key.MulmatFullPrec(ctx, query)
|
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||||
|
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
||||||
kq = kq.Scale(ctx, scaleFactor)
|
|
||||||
kq = kq.Softmax(ctx)
|
|
||||||
|
|
||||||
kqv := value.Mulmat(ctx, kq)
|
|
||||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
||||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
|
||||||
return ca.Output.Forward(ctx, attention)
|
return ca.Output.Forward(ctx, attention)
|
||||||
@@ -197,6 +191,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TextModelOptions struct {
|
type TextModelOptions struct {
|
||||||
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||||
|
|
||||||
hiddenSize, numHeads, numKVHeads int
|
hiddenSize, numHeads, numKVHeads int
|
||||||
eps, ropeBase, ropeScale float32
|
eps, ropeBase, ropeScale float32
|
||||||
ropeDim uint32
|
ropeDim uint32
|
||||||
|
|||||||
@@ -144,6 +144,8 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
|
|||||||
return images
|
return images
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove the "alpha" channel by drawing over a prefilled image
|
||||||
|
//
|
||||||
// remove the "alpha" channel by drawing over a prefilled image
|
// remove the "alpha" channel by drawing over a prefilled image
|
||||||
//
|
//
|
||||||
//nolint:unused
|
//nolint:unused
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
|
||||||
_ "github.com/ollama/ollama/model/models/llama"
|
_ "github.com/ollama/ollama/model/models/llama"
|
||||||
_ "github.com/ollama/ollama/model/models/mllama"
|
_ "github.com/ollama/ollama/model/models/mllama"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"cmp"
|
"cmp"
|
||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -19,17 +18,8 @@ const (
|
|||||||
SpecialEOS
|
SpecialEOS
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
TOKEN_TYPE_NORMAL = iota + 1
|
|
||||||
TOKEN_TYPE_UNKNOWN
|
|
||||||
TOKEN_TYPE_CONTROL
|
|
||||||
TOKEN_TYPE_USER_DEFINED
|
|
||||||
TOKEN_TYPE_UNUSED
|
|
||||||
TOKEN_TYPE_BYTE
|
|
||||||
)
|
|
||||||
|
|
||||||
type TextProcessor interface {
|
type TextProcessor interface {
|
||||||
Encode(s string, addSpecial bool) ([]int32, error)
|
Encode(string) ([]int32, error)
|
||||||
Decode([]int32) (string, error)
|
Decode([]int32) (string, error)
|
||||||
Is(int32, Special) bool
|
Is(int32, Special) bool
|
||||||
}
|
}
|
||||||
@@ -37,11 +27,11 @@ type TextProcessor interface {
|
|||||||
type Vocabulary struct {
|
type Vocabulary struct {
|
||||||
Values []string
|
Values []string
|
||||||
Types []uint32
|
Types []uint32
|
||||||
Scores []float32
|
Scores []uint32
|
||||||
Merges []string
|
Merges []string
|
||||||
|
|
||||||
BOS, EOS, EOT int32
|
BOS, EOS int32
|
||||||
AddBOS, AddEOS, AddEOT bool
|
AddBOS, AddEOS bool
|
||||||
|
|
||||||
specialOnce sync.Once
|
specialOnce sync.Once
|
||||||
special []string
|
special []string
|
||||||
@@ -58,7 +48,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
|
|||||||
case SpecialBOS:
|
case SpecialBOS:
|
||||||
return id == v.BOS
|
return id == v.BOS
|
||||||
case SpecialEOS:
|
case SpecialEOS:
|
||||||
return id == v.EOS || id == v.EOT
|
return id == v.EOS
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -86,9 +76,7 @@ func (v *Vocabulary) Decode(id int32) string {
|
|||||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||||
v.specialOnce.Do(func() {
|
v.specialOnce.Do(func() {
|
||||||
for i := range v.Values {
|
for i := range v.Values {
|
||||||
if slices.Contains([]int{105, 106}, i) {
|
if v.Types[i] == 3 {
|
||||||
v.special = append(v.special, v.Values[i])
|
|
||||||
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
|
|
||||||
v.special = append(v.special, v.Values[i])
|
v.special = append(v.special, v.Values[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -156,7 +144,7 @@ type merge struct {
|
|||||||
runes []rune
|
runes []rune
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||||
fragments := []fragment{{value: s}}
|
fragments := []fragment{{value: s}}
|
||||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||||
// TODO: process special tokens concurrently
|
// TODO: process special tokens concurrently
|
||||||
@@ -189,6 +177,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
for _, frag := range fragments {
|
for _, frag := range fragments {
|
||||||
if len(frag.ids) > 0 {
|
if len(frag.ids) > 0 {
|
||||||
ids = append(ids, frag.ids...)
|
ids = append(ids, frag.ids...)
|
||||||
|
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,6 +201,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
// short circuit if the fragment is in the vocabulary
|
// short circuit if the fragment is in the vocabulary
|
||||||
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
|
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,13 +275,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
// TODO: handle the edge case where the rune isn't in the vocabulary
|
// TODO: handle the edge case where the rune isn't in the vocabulary
|
||||||
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
||||||
ids = append(ids, id)
|
ids = append(ids, id)
|
||||||
|
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if len(ids) > 0 {
|
||||||
if bpe.vocab.AddBOS {
|
if bpe.vocab.AddBOS {
|
||||||
if ids[0] == bpe.vocab.BOS {
|
if ids[0] == bpe.vocab.BOS {
|
||||||
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
|
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
|
||||||
@@ -338,5 +329,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,249 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"iter"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
|
||||||
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
const spmWhitespaceSep = "▁"
|
|
||||||
|
|
||||||
var log = logging.NewLogger()
|
|
||||||
|
|
||||||
func replaceWhitespaceBySeperator(s string) string {
|
|
||||||
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
|
||||||
}
|
|
||||||
|
|
||||||
type SentencePieceModel struct {
|
|
||||||
maxTokenLen int
|
|
||||||
pre *regexp2.Regexp
|
|
||||||
vocab *Vocabulary
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
|
||||||
|
|
||||||
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
|
||||||
log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
|
||||||
|
|
||||||
counter := map[int]int{}
|
|
||||||
var maxTokenLen int
|
|
||||||
for cnt := range vocab.Types {
|
|
||||||
switch vocab.Types[cnt] {
|
|
||||||
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
|
|
||||||
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
|
|
||||||
fallthrough
|
|
||||||
default:
|
|
||||||
counter[int(vocab.Types[cnt])] += 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
|
||||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
|
||||||
"max token len", maxTokenLen)
|
|
||||||
|
|
||||||
return SentencePieceModel{
|
|
||||||
maxTokenLen: maxTokenLen,
|
|
||||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
|
||||||
vocab: vocab,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
|
||||||
return spm.vocab.Is(id, special)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
|
||||||
return func(yield func(string) bool) {
|
|
||||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
|
||||||
if !yield(m.String()) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
|
||||||
fragments := []fragment{{value: s}}
|
|
||||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
|
||||||
// TODO: process special tokens concurrently
|
|
||||||
id := spm.vocab.Encode(special)
|
|
||||||
for i := 0; i < len(fragments); i++ {
|
|
||||||
frag := fragments[i]
|
|
||||||
if len(frag.ids) > 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var middle []fragment
|
|
||||||
switch i := strings.Index(frag.value, special); {
|
|
||||||
case i < 0:
|
|
||||||
middle = append(middle, frag)
|
|
||||||
case i > 0:
|
|
||||||
middle = append(middle, fragment{value: frag.value[:i]})
|
|
||||||
fallthrough
|
|
||||||
default:
|
|
||||||
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
|
||||||
if rest := frag.value[i+len(special):]; rest != "" {
|
|
||||||
middle = append(middle, fragment{value: rest})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
log.Trace("fragments", "frags", fragments)
|
|
||||||
|
|
||||||
var ids []int32
|
|
||||||
for _, frag := range fragments {
|
|
||||||
if len(frag.ids) > 0 {
|
|
||||||
ids = append(ids, frag.ids...)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for split := range spm.split(frag.value) {
|
|
||||||
split = replaceWhitespaceBySeperator(split)
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.Write([]byte(split))
|
|
||||||
if id := spm.vocab.Encode(sb.String()); id >= 0 {
|
|
||||||
ids = append(ids, id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
runes := []rune(sb.String())
|
|
||||||
pq := queue.NewWith(func(a, b any) int {
|
|
||||||
priA := a.(*candidate)
|
|
||||||
priB := b.(*candidate)
|
|
||||||
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
})
|
|
||||||
|
|
||||||
merges := make([]merge, len(runes))
|
|
||||||
for r := range runes {
|
|
||||||
merges[r] = merge{
|
|
||||||
p: r - 1,
|
|
||||||
n: r + 1,
|
|
||||||
runes: []rune{runes[r]},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace("tokenizer", "merges", merges)
|
|
||||||
|
|
||||||
pairwise := func(a, b int) *candidate {
|
|
||||||
if a < 0 || b >= len(runes) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
|
||||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
|
||||||
return &candidate{
|
|
||||||
a: a,
|
|
||||||
b: b,
|
|
||||||
score: spm.vocab.Scores[id],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range len(runes) - 1 {
|
|
||||||
if pair := pairwise(i, i+1); pair != nil {
|
|
||||||
pq.Enqueue(pair)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pqv := pq.Values()
|
|
||||||
for _, v := range pqv {
|
|
||||||
e := v.(*candidate)
|
|
||||||
log.Trace("candidate", "candidate", e)
|
|
||||||
}
|
|
||||||
|
|
||||||
for !pq.Empty() {
|
|
||||||
v, _ := pq.Dequeue()
|
|
||||||
pair := v.(*candidate)
|
|
||||||
left, right := merges[pair.a], merges[pair.b]
|
|
||||||
|
|
||||||
log.Trace("pair", "left", left, "right", right)
|
|
||||||
if len(left.runes) == 0 || len(right.runes) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
|
||||||
merges[pair.b].runes = nil
|
|
||||||
merges[pair.a].n = right.n
|
|
||||||
if right.n < len(merges) {
|
|
||||||
merges[right.n].p = pair.a
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
|
||||||
pq.Enqueue(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
|
||||||
pq.Enqueue(pair)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace("merges", "merges", merges)
|
|
||||||
|
|
||||||
for _, merge := range merges {
|
|
||||||
if len(merge.runes) > 0 {
|
|
||||||
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
|
||||||
ids = append(ids, id)
|
|
||||||
} else {
|
|
||||||
log.Error("missing token", "token", string(merge.runes))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
|
||||||
if spm.vocab.AddBOS {
|
|
||||||
if ids[0] == spm.vocab.BOS {
|
|
||||||
log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
|
|
||||||
ids = append([]int32{spm.vocab.BOS}, ids...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if spm.vocab.AddEOS {
|
|
||||||
if ids[len(ids)-1] == spm.vocab.EOS {
|
|
||||||
log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
|
|
||||||
ids = append(ids, spm.vocab.EOS)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ids, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type candidate struct {
|
|
||||||
a, b int
|
|
||||||
score float32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|
||||||
var sb strings.Builder
|
|
||||||
for _, id := range ids {
|
|
||||||
data := spm.vocab.Decode(id)
|
|
||||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
|
||||||
if _, err := sb.WriteString(data); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("decoded", "ids", ids, "text", sb.String())
|
|
||||||
return sb.String(), nil
|
|
||||||
}
|
|
||||||
@@ -1,118 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"slices"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/convert/sentencepiece"
|
|
||||||
)
|
|
||||||
|
|
||||||
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var spm sentencepiece.ModelProto
|
|
||||||
if err := proto.Unmarshal(bts, &spm); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
|
|
||||||
|
|
||||||
var v Vocabulary
|
|
||||||
|
|
||||||
for _, piece := range spm.GetPieces() {
|
|
||||||
v.Values = append(v.Values, piece.GetPiece())
|
|
||||||
v.Scores = append(v.Scores, piece.GetScore())
|
|
||||||
switch t := piece.GetType(); t {
|
|
||||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
|
|
||||||
sentencepiece.ModelProto_SentencePiece_CONTROL,
|
|
||||||
sentencepiece.ModelProto_SentencePiece_UNUSED,
|
|
||||||
sentencepiece.ModelProto_SentencePiece_BYTE:
|
|
||||||
v.Types = append(v.Types, uint32(t))
|
|
||||||
default:
|
|
||||||
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
|
||||||
// todo parse the special tokens file
|
|
||||||
// - this will roundtrip correctly but the <start_of_turn> and
|
|
||||||
// <end_of_turn> tokens aren't processed
|
|
||||||
v.Types = append(v.Types, tt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return NewSentencePieceModel(preTokenizer, &v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSentencePieceEncode(t *testing.T) {
|
|
||||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
|
||||||
slog.SetDefault(logger)
|
|
||||||
|
|
||||||
tokenizer := loadSentencePieceVocab(t)
|
|
||||||
|
|
||||||
t.Run("basic roundtrip", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cases := []string{
|
|
||||||
"hello",
|
|
||||||
"hello ",
|
|
||||||
"hello ",
|
|
||||||
" hello",
|
|
||||||
" hello ",
|
|
||||||
" hello ",
|
|
||||||
"hello world",
|
|
||||||
"请考试我的软件!12345",
|
|
||||||
"你好",
|
|
||||||
"Hello 你好 world!",
|
|
||||||
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
|
|
||||||
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
|
|
||||||
"Numbers and symbols: 123456789 +- */",
|
|
||||||
"Special tokens: <bos> text <eos>",
|
|
||||||
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
|
|
||||||
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
|
|
||||||
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
|
|
||||||
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, want := range cases {
|
|
||||||
ids, err := tokenizer.Encode(want, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, err := tokenizer.Decode(ids); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
} else if got != want {
|
|
||||||
t.Errorf("got %q, want %q [%#v]", got, want, ids)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("special tokens", func(t *testing.T) {
|
|
||||||
type candidate struct {
|
|
||||||
token string
|
|
||||||
ids []int32
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []candidate{
|
|
||||||
{"<bos>", []int32{2}},
|
|
||||||
{"<eos>", []int32{1}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, want := range cases {
|
|
||||||
ids, err := tokenizer.Encode(want.token, true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if !slices.Equal(ids, want.ids) {
|
|
||||||
t.Errorf("got %#v, want %#v", ids, want.ids)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
|
|||||||
t.Run("simple", func(t *testing.T) {
|
t.Run("simple", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ids, err := tokenizer.Encode("hello world", true)
|
ids, err := tokenizer.Encode("hello world")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
|
|||||||
t.Errorf("got %q, want hello world", s)
|
t.Errorf("got %q, want hello world", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
ids, err = tokenizer.Encode("hello <|end_of_text|>")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for s, want := range cases {
|
for s, want := range cases {
|
||||||
ids, err := tokenizer.Encode(s, true)
|
ids, err := tokenizer.Encode(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, want := range cases {
|
for _, want := range cases {
|
||||||
ids, err := tokenizer.Encode(want, true)
|
ids, err := tokenizer.Encode(want)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for s, want := range cases {
|
for s, want := range cases {
|
||||||
ids, err := tokenizer.Encode(s, true)
|
ids, err := tokenizer.Encode(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|||||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for range b.N {
|
for range b.N {
|
||||||
_, err := tokenizer.Encode(string(bts), true)
|
_, err := tokenizer.Encode(string(bts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
ids, err := tokenizer.Encode(string(bts), true)
|
ids, err := tokenizer.Encode(string(bts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
BIN
model/testdata/gemma2/tokenizer.model
vendored
BIN
model/testdata/gemma2/tokenizer.model
vendored
Binary file not shown.
@@ -116,9 +116,19 @@ func (i *Instance) Readline() (string, error) {
|
|||||||
|
|
||||||
switch r {
|
switch r {
|
||||||
case KeyUp:
|
case KeyUp:
|
||||||
i.historyPrev(buf, ¤tLineBuf)
|
if i.History.Pos > 0 {
|
||||||
|
if i.History.Pos == i.History.Size() {
|
||||||
|
currentLineBuf = []rune(buf.String())
|
||||||
|
}
|
||||||
|
buf.Replace([]rune(i.History.Prev()))
|
||||||
|
}
|
||||||
case KeyDown:
|
case KeyDown:
|
||||||
i.historyNext(buf, ¤tLineBuf)
|
if i.History.Pos < i.History.Size() {
|
||||||
|
buf.Replace([]rune(i.History.Next()))
|
||||||
|
if i.History.Pos == i.History.Size() {
|
||||||
|
buf.Replace(currentLineBuf)
|
||||||
|
}
|
||||||
|
}
|
||||||
case KeyLeft:
|
case KeyLeft:
|
||||||
buf.MoveLeft()
|
buf.MoveLeft()
|
||||||
case KeyRight:
|
case KeyRight:
|
||||||
@@ -175,10 +185,6 @@ func (i *Instance) Readline() (string, error) {
|
|||||||
esc = true
|
esc = true
|
||||||
case CharInterrupt:
|
case CharInterrupt:
|
||||||
return "", ErrInterrupt
|
return "", ErrInterrupt
|
||||||
case CharPrev:
|
|
||||||
i.historyPrev(buf, ¤tLineBuf)
|
|
||||||
case CharNext:
|
|
||||||
i.historyNext(buf, ¤tLineBuf)
|
|
||||||
case CharLineStart:
|
case CharLineStart:
|
||||||
buf.MoveToStart()
|
buf.MoveToStart()
|
||||||
case CharLineEnd:
|
case CharLineEnd:
|
||||||
@@ -240,24 +246,6 @@ func (i *Instance) HistoryDisable() {
|
|||||||
i.History.Enabled = false
|
i.History.Enabled = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) {
|
|
||||||
if i.History.Pos > 0 {
|
|
||||||
if i.History.Pos == i.History.Size() {
|
|
||||||
*currentLineBuf = []rune(buf.String())
|
|
||||||
}
|
|
||||||
buf.Replace([]rune(i.History.Prev()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) {
|
|
||||||
if i.History.Pos < i.History.Size() {
|
|
||||||
buf.Replace([]rune(i.History.Next()))
|
|
||||||
if i.History.Pos == i.History.Size() {
|
|
||||||
buf.Replace(*currentLineBuf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTerminal() (*Terminal, error) {
|
func NewTerminal() (*Terminal, error) {
|
||||||
fd := os.Stdin.Fd()
|
fd := os.Stdin.Fd()
|
||||||
termios, err := SetRawMode(fd)
|
termios, err := SetRawMode(fd)
|
||||||
|
|||||||
@@ -931,6 +931,7 @@ func Execute(args []string) error {
|
|||||||
slog.Info("starting go runner")
|
slog.Info("starting go runner")
|
||||||
|
|
||||||
llama.BackendInit()
|
llama.BackendInit()
|
||||||
|
slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
|
||||||
|
|
||||||
server := &Server{
|
server := &Server{
|
||||||
batchSize: *batchSize,
|
batchSize: *batchSize,
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/kvcache"
|
"github.com/ollama/ollama/kvcache"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InputCache struct {
|
type InputCache struct {
|
||||||
@@ -39,7 +39,10 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
|||||||
slots := make([]InputCacheSlot, numSlots)
|
slots := make([]InputCacheSlot, numSlots)
|
||||||
|
|
||||||
for i := range slots {
|
for i := range slots {
|
||||||
slots[i] = InputCacheSlot{Id: i}
|
slots[i] = InputCacheSlot{
|
||||||
|
Id: i,
|
||||||
|
Inputs: make([]input, 0),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cache := model.Config().Cache
|
cache := model.Config().Cache
|
||||||
@@ -59,9 +62,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
|
|||||||
func kvCacheTypeFromStr(s string) ml.DType {
|
func kvCacheTypeFromStr(s string) ml.DType {
|
||||||
switch s {
|
switch s {
|
||||||
case "q8_0":
|
case "q8_0":
|
||||||
return ml.DTypeQ80
|
panic("kv cache quantization not yet implemented")
|
||||||
case "q4_0":
|
case "q4_0":
|
||||||
return ml.DTypeQ40
|
panic("kv cache quantization not yet implemented")
|
||||||
default:
|
default:
|
||||||
return ml.DTypeF16
|
return ml.DTypeF16
|
||||||
}
|
}
|
||||||
@@ -80,7 +83,7 @@ type InputCacheSlot struct {
|
|||||||
Id int
|
Id int
|
||||||
|
|
||||||
// Inputs that are stored in the KV cache
|
// Inputs that are stored in the KV cache
|
||||||
Inputs []input.Input
|
Inputs []input
|
||||||
|
|
||||||
// is this cache actively being processed as part of a sequence?
|
// is this cache actively being processed as part of a sequence?
|
||||||
InUse bool
|
InUse bool
|
||||||
@@ -89,7 +92,7 @@ type InputCacheSlot struct {
|
|||||||
lastUsed time.Time
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
|
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
|
||||||
var slot *InputCacheSlot
|
var slot *InputCacheSlot
|
||||||
var numPast int32
|
var numPast int32
|
||||||
var err error
|
var err error
|
||||||
@@ -140,7 +143,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
|
|||||||
return slot, prompt, nil
|
return slot, prompt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
|
||||||
longest := int32(-1)
|
longest := int32(-1)
|
||||||
var longestSlot *InputCacheSlot
|
var longestSlot *InputCacheSlot
|
||||||
|
|
||||||
@@ -163,7 +166,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
|
|||||||
return longestSlot, longest, nil
|
return longestSlot, longest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
|
||||||
oldest := time.Now()
|
oldest := time.Now()
|
||||||
var oldestSlot *InputCacheSlot
|
var oldestSlot *InputCacheSlot
|
||||||
|
|
||||||
@@ -199,7 +202,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
|||||||
if longest > 0 && longestSlot != oldestSlot {
|
if longest > 0 && longestSlot != oldestSlot {
|
||||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||||
len(longestSlot.Inputs))
|
len(longestSlot.Inputs))
|
||||||
oldestSlot.Inputs = make([]input.Input, longest)
|
oldestSlot.Inputs = make([]input, longest)
|
||||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||||
if c.cache != nil {
|
if c.cache != nil {
|
||||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||||
@@ -209,7 +212,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
|||||||
return oldestSlot, longest, nil
|
return oldestSlot, longest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
func countCommonPrefix(a []input, b []input) int32 {
|
||||||
var count int32
|
var count int32
|
||||||
|
|
||||||
for i := range a {
|
for i := range a {
|
||||||
@@ -217,7 +220,7 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
|
if !reflect.DeepEqual(a[i], b[i]) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import (
|
|||||||
"image"
|
"image"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCountCommon(t *testing.T) {
|
func TestCountCommon(t *testing.T) {
|
||||||
@@ -15,50 +13,44 @@ func TestCountCommon(t *testing.T) {
|
|||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
t1 []input.Input
|
t1 []input
|
||||||
t2 []input.Input
|
t2 []input
|
||||||
expected int32
|
expected int32
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Equal",
|
name: "Equal",
|
||||||
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t1: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
expected: 3,
|
expected: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Prefix",
|
name: "Prefix",
|
||||||
t1: []input.Input{{Token: 1}},
|
t1: []input{{token: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Image Prefix",
|
name: "Image Prefix",
|
||||||
t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
|
t1: []input{{image: imgA}},
|
||||||
t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
|
t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Mixed",
|
name: "Mixed",
|
||||||
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
|
t1: []input{{token: 1}, {image: imgA}},
|
||||||
t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
|
t2: []input{{token: 1}, {image: imgA}, {token: 5}},
|
||||||
expected: 2,
|
expected: 2,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Mixed, Same Length",
|
|
||||||
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
|
|
||||||
t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
|
|
||||||
expected: 1,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Empty",
|
name: "Empty",
|
||||||
t1: []input.Input{},
|
t1: []input{},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||||
expected: 0,
|
expected: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Both Empty",
|
name: "Both Empty",
|
||||||
t1: []input.Input{},
|
t1: []input{},
|
||||||
t2: []input.Input{},
|
t2: []input{},
|
||||||
expected: 0,
|
expected: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -82,7 +74,7 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cache InputCache
|
cache InputCache
|
||||||
prompt []input.Input
|
prompt []input
|
||||||
longest expected
|
longest expected
|
||||||
best expected
|
best expected
|
||||||
}{
|
}{
|
||||||
@@ -91,18 +83,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{},
|
Inputs: []input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}},
|
prompt: []input{{token: 1}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 0, len: 0},
|
best: expected{result: 0, len: 0},
|
||||||
},
|
},
|
||||||
@@ -111,18 +103,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []input{{token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []input{{token: 1}, {token: 2}},
|
||||||
longest: expected{result: 1, len: 2},
|
longest: expected{result: 1, len: 2},
|
||||||
best: expected{result: 1, len: 2},
|
best: expected{result: 1, len: 2},
|
||||||
},
|
},
|
||||||
@@ -131,18 +123,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 2}},
|
prompt: []input{{token: 2}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 1, len: 0},
|
best: expected{result: 1, len: 0},
|
||||||
},
|
},
|
||||||
@@ -152,19 +144,19 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}},
|
prompt: []input{{token: 1}},
|
||||||
longest: expected{result: 0, len: 1},
|
longest: expected{result: 0, len: 1},
|
||||||
best: expected{result: 1, len: 1},
|
best: expected{result: 1, len: 1},
|
||||||
},
|
},
|
||||||
@@ -173,18 +165,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []input{{token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 2}, {Token: 3}},
|
prompt: []input{{token: 2}, {token: 3}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 1, len: 0},
|
best: expected{result: 1, len: 0},
|
||||||
},
|
},
|
||||||
@@ -193,18 +185,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []input{{token: 1}, {token: 2}},
|
||||||
InUse: true,
|
InUse: true,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []input{{token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []input{{token: 1}, {token: 2}},
|
||||||
longest: expected{result: 1, len: 1},
|
longest: expected{result: 1, len: 1},
|
||||||
best: expected{result: 1, len: 2},
|
best: expected{result: 1, len: 2},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
package ollamarunner
|
package ollamarunner
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/maphash"
|
"image"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
@@ -26,26 +27,28 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
"github.com/ollama/ollama/model"
|
"github.com/ollama/ollama/model"
|
||||||
"github.com/ollama/ollama/model/input"
|
|
||||||
"github.com/ollama/ollama/runner/common"
|
"github.com/ollama/ollama/runner/common"
|
||||||
"github.com/ollama/ollama/sample"
|
"github.com/ollama/ollama/sample"
|
||||||
|
|
||||||
_ "github.com/ollama/ollama/model/models"
|
_ "github.com/ollama/ollama/model/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Sequence struct {
|
// input is an element of the prompt to process, either a token or an image
|
||||||
// ctx for allocating tensors that last the lifetime of the sequence, such as
|
type input struct {
|
||||||
// multimodal embeddings
|
token int32
|
||||||
ctx ml.Context
|
|
||||||
|
|
||||||
|
image image.Image
|
||||||
|
}
|
||||||
|
|
||||||
|
type Sequence struct {
|
||||||
// batch index
|
// batch index
|
||||||
iBatch int
|
iBatch int
|
||||||
|
|
||||||
// prompt inputs left to evaluate
|
// prompt inputs left to evaluate
|
||||||
inputs []input.Input
|
inputs []input
|
||||||
|
|
||||||
// inputs that have been added to a batch but not yet submitted to Forward
|
// inputs that have been added to a batch but not yet submitted to Forward
|
||||||
pendingInputs []input.Input
|
pendingInputs []input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []string
|
||||||
@@ -98,9 +101,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
ctx := s.model.Backend().NewContext()
|
|
||||||
|
|
||||||
inputs, err := s.inputs(ctx, prompt, images)
|
inputs, err := s.inputs(prompt, images)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||||
} else if len(inputs) == 0 {
|
} else if len(inputs) == 0 {
|
||||||
@@ -126,7 +128,6 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
// TODO(jessegross): Ingest cached history for grammar
|
// TODO(jessegross): Ingest cached history for grammar
|
||||||
|
|
||||||
return &Sequence{
|
return &Sequence{
|
||||||
ctx: ctx,
|
|
||||||
inputs: inputs,
|
inputs: inputs,
|
||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
@@ -145,31 +146,28 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// decoding images
|
// decoding images
|
||||||
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
|
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||||
var inputs []input.Input
|
var inputs []input
|
||||||
var parts []string
|
var parts []string
|
||||||
var matches [][]string
|
var matches [][]string
|
||||||
|
|
||||||
multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor)
|
// TODO(jessegross): This can sometimes trigger for matching text in the
|
||||||
|
// user's prompt. We previously tried to avoid it by only looking for images
|
||||||
|
// on image models. We don't have a clear indication now but it would be better
|
||||||
|
// to properly escape it in any case.
|
||||||
|
re := regexp.MustCompile(`\[img-(\d+)\]`)
|
||||||
|
parts = re.Split(prompt, -1)
|
||||||
|
matches = re.FindAllStringSubmatch(prompt, -1)
|
||||||
|
|
||||||
if visionModel {
|
|
||||||
re := regexp.MustCompile(`\[img-(\d+)\]`)
|
|
||||||
parts = re.Split(prompt, -1)
|
|
||||||
matches = re.FindAllStringSubmatch(prompt, -1)
|
|
||||||
} else {
|
|
||||||
parts = []string{prompt}
|
|
||||||
}
|
|
||||||
|
|
||||||
postTokenize := false
|
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
// text - tokenize
|
// text - tokenize
|
||||||
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
tokens, err := s.model.(model.TextProcessor).Encode(part)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tokens {
|
for _, t := range tokens {
|
||||||
inputs = append(inputs, input.Input{Token: t})
|
inputs = append(inputs, input{token: t})
|
||||||
}
|
}
|
||||||
|
|
||||||
// image - decode and store
|
// image - decode and store
|
||||||
@@ -188,25 +186,12 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
|
|||||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
|
image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.multimodalHash.Reset()
|
inputs = append(inputs, input{image: image})
|
||||||
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
|
|
||||||
imageHash := s.multimodalHash.Sum64()
|
|
||||||
|
|
||||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
|
||||||
postTokenize = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if visionModel && postTokenize {
|
|
||||||
var err error
|
|
||||||
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,15 +236,8 @@ type Server struct {
|
|||||||
// KV cache
|
// KV cache
|
||||||
cache *InputCache
|
cache *InputCache
|
||||||
|
|
||||||
// multimodalHash generates hashes for comparing equality
|
// next sequence for prompt processing to avoid starvation
|
||||||
// of non-text data
|
nextSeq int
|
||||||
multimodalHash maphash.Hash
|
|
||||||
|
|
||||||
// vocab is a llama.cpp vocab required for gammar-based
|
|
||||||
// constrained generation (json mode, structured outputs)
|
|
||||||
// TODO: this is temporary until Ollama sampling supports
|
|
||||||
// constrained generation
|
|
||||||
vocab *sample.Vocab
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) allNil() bool {
|
func (s *Server) allNil() bool {
|
||||||
@@ -305,7 +283,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
|
|||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
seq.cache.InUse = false
|
seq.cache.InUse = false
|
||||||
seq.ctx.Close()
|
|
||||||
s.seqs[seqIndex] = nil
|
s.seqs[seqIndex] = nil
|
||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
}
|
}
|
||||||
@@ -333,25 +310,30 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
var options input.Options
|
var options model.Options
|
||||||
|
imgSeq := -1
|
||||||
|
|
||||||
|
seqIdx := s.nextSeq - 1
|
||||||
|
for range s.seqs {
|
||||||
|
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||||
|
seq := s.seqs[seqIdx]
|
||||||
|
|
||||||
for i, seq := range s.seqs {
|
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// if past the num predict limit
|
// if past the num predict limit
|
||||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||||
s.removeSequence(i, "limit")
|
s.removeSequence(seqIdx, "limit")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.cache.enabled {
|
if !s.cache.enabled {
|
||||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||||
seq.cache.Inputs = []input.Input{}
|
seq.cache.Inputs = []input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
for j, inp := range seq.inputs {
|
for i, input := range seq.inputs {
|
||||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||||
if len(seq.pendingInputs) == 0 {
|
if len(seq.pendingInputs) == 0 {
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
@@ -363,23 +345,37 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if j >= s.batchSize {
|
if i >= s.batchSize {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
options.Inputs = append(options.Inputs, inp.Token)
|
// TODO(jessegross): Image inputs need to be rethought - it's
|
||||||
if inp.Multimodal != nil {
|
// it doesn't work well for different types of models or multiple sequences
|
||||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
if input.image != nil {
|
||||||
|
if len(seq.pendingInputs) != len(options.Images) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if imgSeq != seqIdx && imgSeq != -1 {
|
||||||
|
s.nextSeq = seqIdx
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
imgSeq = seqIdx
|
||||||
|
options.Images = append(options.Images, input.image)
|
||||||
|
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
options.Inputs = append(options.Inputs, input.token)
|
||||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
options.Sequences = append(options.Sequences, seq.cache.Id)
|
||||||
|
|
||||||
seq.iBatch = len(options.Outputs)
|
seq.iBatch = len(options.Outputs)
|
||||||
if j+1 == len(seq.inputs) {
|
if i+1 == len(seq.inputs) {
|
||||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
||||||
}
|
}
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||||
@@ -407,7 +403,7 @@ func (s *Server) processBatch() error {
|
|||||||
// After calling Forward, pending inputs are now in the cache
|
// After calling Forward, pending inputs are now in the cache
|
||||||
if len(seq.pendingInputs) > 0 {
|
if len(seq.pendingInputs) > 0 {
|
||||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
seq.pendingInputs = []input.Input{}
|
seq.pendingInputs = []input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't sample prompt processing
|
// don't sample prompt processing
|
||||||
@@ -426,7 +422,6 @@ func (s *Server) processBatch() error {
|
|||||||
// if done processing the prompt, generate an embedding and return
|
// if done processing the prompt, generate an embedding and return
|
||||||
if seq.embeddingOnly {
|
if seq.embeddingOnly {
|
||||||
// TODO(jessegross): Embedding support
|
// TODO(jessegross): Embedding support
|
||||||
slog.Warn("generation of embedding outputs not yet supported")
|
|
||||||
s.removeSequence(i, "")
|
s.removeSequence(i, "")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -454,7 +449,7 @@ func (s *Server) processBatch() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.inputs = []input.Input{{Token: token}}
|
seq.inputs = []input{{token: token}}
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := strings.Join(seq.pendingResponses, "")
|
||||||
@@ -580,30 +575,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var grammar *sample.Grammar
|
|
||||||
var err error
|
|
||||||
if req.Grammar != "" {
|
|
||||||
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sampler := sample.NewSampler(
|
|
||||||
req.Temperature,
|
|
||||||
req.TopK,
|
|
||||||
req.TopP,
|
|
||||||
req.MinP,
|
|
||||||
req.Seed,
|
|
||||||
grammar,
|
|
||||||
)
|
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
numPredict: req.NumPredict,
|
numPredict: req.NumPredict,
|
||||||
stop: req.Stop,
|
stop: req.Stop,
|
||||||
numKeep: int32(req.NumKeep),
|
numKeep: int32(req.NumKeep),
|
||||||
sampler: sampler,
|
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
|
||||||
embedding: false,
|
embedding: false,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -691,6 +667,65 @@ type EmbeddingResponse struct {
|
|||||||
Embedding []float32 `json:"embedding"`
|
Embedding []float32 `json:"embedding"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req EmbeddingRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
slog.Debug("embedding request", "content", req.Content)
|
||||||
|
|
||||||
|
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||||
|
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
slog.Info("aborting embeddings request due to client closing the connection")
|
||||||
|
} else {
|
||||||
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
found := false
|
||||||
|
for i, sq := range s.seqs {
|
||||||
|
if sq == nil {
|
||||||
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.seqs[i] = seq
|
||||||
|
s.cond.Signal()
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding := <-seq.embedding
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
|
||||||
|
Embedding: embedding,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type HealthResponse struct {
|
type HealthResponse struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Progress float32 `json:"progress"`
|
Progress float32 `json:"progress"`
|
||||||
@@ -751,7 +786,7 @@ func (s *Server) loadModel(
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.vocab = sample.NewVocab(mpath)
|
slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
|
||||||
|
|
||||||
// TODO(jessegross): LoRA loading
|
// TODO(jessegross): LoRA loading
|
||||||
if lpath.String() != "" {
|
if lpath.String() != "" {
|
||||||
@@ -783,7 +818,7 @@ func Execute(args []string) error {
|
|||||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||||
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||||
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
||||||
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
_ = fs.Bool("flash-attn", false, "Enable flash attention")
|
||||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||||
@@ -828,6 +863,7 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jessegross): Parameters that need to be implemented:
|
// TODO(jessegross): Parameters that need to be implemented:
|
||||||
|
// flash-attn
|
||||||
// no-mmap
|
// no-mmap
|
||||||
// mlock
|
// mlock
|
||||||
|
|
||||||
@@ -842,11 +878,10 @@ func Execute(args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := ml.BackendParams{
|
params := ml.BackendParams{
|
||||||
NumThreads: *threads,
|
NumThreads: *threads,
|
||||||
NumGPULayers: *numGPULayers,
|
NumGPULayers: *numGPULayers,
|
||||||
MainGPU: *mainGPU,
|
MainGPU: *mainGPU,
|
||||||
TensorSplit: tensorSplitFloats,
|
TensorSplit: tensorSplitFloats,
|
||||||
FlashAttention: *flashAttention,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
server.ready.Add(1)
|
server.ready.Add(1)
|
||||||
@@ -868,13 +903,9 @@ func Execute(args []string) error {
|
|||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
// TODO: support embeddings
|
mux.HandleFunc("/embedding", server.embeddings)
|
||||||
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/completion", server.completion)
|
||||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
mux.HandleFunc("/health", server.health)
|
||||||
})
|
|
||||||
|
|
||||||
mux.HandleFunc("POST /completion", server.completion)
|
|
||||||
mux.HandleFunc("GET /health", server.health)
|
|
||||||
|
|
||||||
httpServer := http.Server{
|
httpServer := http.Server{
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
|
|||||||
@@ -3,223 +3,118 @@ package sample
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
|
||||||
"slices"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/llama"
|
"golang.org/x/exp/rand"
|
||||||
|
"gonum.org/v1/gonum/stat/sampleuv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// token represents information about a single token during sampling
|
type Sampler interface {
|
||||||
type token struct {
|
Sample([]float32) (int32, error)
|
||||||
id int32 // The token's unique identifier
|
|
||||||
value float32 // The raw logit or probability from the model
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Sampler struct {
|
type weighted struct {
|
||||||
rng *rand.Rand
|
src rand.Source
|
||||||
topK int
|
transforms []Transform
|
||||||
topP float32
|
|
||||||
minP float32
|
|
||||||
temperature float32
|
|
||||||
grammar *Grammar
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
|
||||||
tokens := make([]token, len(logits))
|
func Weighted(seed *uint64, transforms ...Transform) Sampler {
|
||||||
|
var src rand.Source
|
||||||
|
if seed != nil {
|
||||||
|
src = rand.NewSource(*seed)
|
||||||
|
}
|
||||||
|
return weighted{src: src, transforms: transforms}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s weighted) Sample(logits []float32) (int32, error) {
|
||||||
|
logits64 := make([]float64, len(logits))
|
||||||
|
for i, v := range logits {
|
||||||
|
logits64[i] = float64(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range s.transforms {
|
||||||
|
logits64 = t.Apply(logits64)
|
||||||
|
}
|
||||||
|
|
||||||
|
logitsCopy := make([]float64, 0, len(logits))
|
||||||
|
indices := make([]int, 0, len(logits))
|
||||||
|
for i, logit := range logits64 {
|
||||||
|
if !math.IsInf(logit, -1) {
|
||||||
|
logitsCopy = append(logitsCopy, logit)
|
||||||
|
indices = append(indices, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(logitsCopy) == 0 {
|
||||||
|
return -1, errors.New("no valid logits found for weighed sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
probs := softmax(logitsCopy)
|
||||||
|
w := sampleuv.NewWeighted(probs, s.src)
|
||||||
|
if idx, ok := w.Take(); ok {
|
||||||
|
return int32(indices[idx]), nil
|
||||||
|
}
|
||||||
|
return -1, errors.New("weighted sampler failed, no valid token found")
|
||||||
|
}
|
||||||
|
|
||||||
|
type greedy struct{}
|
||||||
|
|
||||||
|
func Greedy() Sampler {
|
||||||
|
return greedy{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sample returns the index of the maximum value in logits.
|
||||||
|
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||||
|
if len(logits) == 0 {
|
||||||
|
return -1, errors.New("no logits provided for greedy sampling")
|
||||||
|
}
|
||||||
|
|
||||||
|
maxIdx := 0
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
tokens[i].id = int32(i)
|
if logits[i] > logits[maxIdx] {
|
||||||
tokens[i].value = logits[i]
|
maxIdx = i
|
||||||
}
|
|
||||||
|
|
||||||
t, err := s.sample(tokens)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.grammar != nil {
|
|
||||||
// optimization: first check if the max logit is accepted by the grammar
|
|
||||||
// if the max logit is rejected, apply the grammar to all logits (slower)
|
|
||||||
top := []token{t}
|
|
||||||
s.grammar.Apply(top)
|
|
||||||
if !math.IsInf(float64(top[0].value), -1) {
|
|
||||||
s.grammar.Accept(top[0].id)
|
|
||||||
return top[0].id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// since .sample has side effects of modifying the tokens
|
|
||||||
// we need to reset them before applying the grammar and
|
|
||||||
// sampling again
|
|
||||||
for i := range logits {
|
|
||||||
tokens[i].id = int32(i)
|
|
||||||
tokens[i].value = logits[i]
|
|
||||||
}
|
|
||||||
s.grammar.Apply(tokens)
|
|
||||||
t, err = s.sample(tokens)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
s.grammar.Accept(t.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.id, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// greedy returns the highest probability token from the tokens
|
|
||||||
func greedy(tokens []token) token {
|
|
||||||
max := tokens[0]
|
|
||||||
for i := 1; i < len(tokens); i++ {
|
|
||||||
if tokens[i].value > max.value {
|
|
||||||
max = tokens[i]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return max
|
return int32(maxIdx), nil
|
||||||
}
|
|
||||||
|
|
||||||
// sample returns the highest probability token from the tokens
|
|
||||||
// given sampler parameters. It also has side effects of modifying the tokens
|
|
||||||
func (s *Sampler) sample(tokens []token) (token, error) {
|
|
||||||
if s.temperature == 0 {
|
|
||||||
return greedy(tokens), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// topK also sorts the tokens in descending order of logits
|
|
||||||
tokens = topK(tokens, s.topK)
|
|
||||||
|
|
||||||
tokens = temperature(tokens, s.temperature)
|
|
||||||
tokens = softmax(tokens)
|
|
||||||
|
|
||||||
tokens = topP(tokens, s.topP)
|
|
||||||
tokens = minP(tokens, s.minP)
|
|
||||||
|
|
||||||
// TODO: this should fall back to greedy sampling
|
|
||||||
// or topP, topK values etc should be such that
|
|
||||||
// there are always tokens to sample from
|
|
||||||
if len(tokens) == 0 {
|
|
||||||
return token{}, errors.New("no tokens to sample from")
|
|
||||||
}
|
|
||||||
|
|
||||||
var r float32
|
|
||||||
if s.rng != nil {
|
|
||||||
r = s.rng.Float32()
|
|
||||||
} else {
|
|
||||||
r = rand.Float32()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate cumulative sum of probabilities
|
|
||||||
var sum float32
|
|
||||||
for i := range tokens {
|
|
||||||
sum += tokens[i].value
|
|
||||||
tokens[i].value = sum
|
|
||||||
}
|
|
||||||
r *= tokens[len(tokens)-1].value
|
|
||||||
|
|
||||||
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
|
|
||||||
if token.value < target {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
})
|
|
||||||
|
|
||||||
return tokens[idx], nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
|
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
|
||||||
var rng *rand.Rand
|
if temperature == 0 {
|
||||||
if seed != -1 {
|
return Greedy(), nil
|
||||||
// PCG requires two parameters: sequence and stream
|
|
||||||
// Use original seed for sequence
|
|
||||||
sequence := uint64(seed)
|
|
||||||
// Use golden ratio hash to generate statistically independent seeds
|
|
||||||
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
|
|
||||||
}
|
|
||||||
if temperature < 0.0 {
|
|
||||||
temperature = 0.0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if topP < 0.0 {
|
if temperature < 0 || temperature > 2 {
|
||||||
topP = 0.0
|
return nil, errors.New("temperature must be between 0 and 2")
|
||||||
}
|
|
||||||
if topP >= 1.0 {
|
|
||||||
topP = 1.0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if minP < 0.0 {
|
transforms := []Transform{Temperature(temperature)}
|
||||||
minP = 0.0
|
|
||||||
}
|
|
||||||
if minP >= 1.0 {
|
|
||||||
minP = 1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
return Sampler{
|
if topK != 0 {
|
||||||
rng: rng,
|
if topK <= 0 {
|
||||||
topK: topK,
|
return nil, errors.New("topK must be greater than 0")
|
||||||
topP: topP,
|
|
||||||
minP: minP,
|
|
||||||
temperature: temperature,
|
|
||||||
grammar: grammar,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Grammar struct {
|
|
||||||
vocab *Vocab
|
|
||||||
grammar string
|
|
||||||
sampler *llama.Sampler
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
|
|
||||||
v, err := vocab.Load()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Grammar{
|
|
||||||
vocab: vocab,
|
|
||||||
grammar: grammar,
|
|
||||||
sampler: llama.NewGrammarSampler(v, grammar),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Grammar) Apply(tokens []token) {
|
|
||||||
tds := make([]llama.TokenData, len(tokens))
|
|
||||||
for i, token := range tokens {
|
|
||||||
tds[i].Id = token.id
|
|
||||||
tds[i].Logit = token.value
|
|
||||||
}
|
|
||||||
|
|
||||||
g.sampler.Apply(tds)
|
|
||||||
|
|
||||||
for i := range tokens {
|
|
||||||
tokens[i].value = tds[i].Logit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *Grammar) Accept(token int32) {
|
|
||||||
g.sampler.Accept(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Vocab struct {
|
|
||||||
once sync.Once
|
|
||||||
vocab *llama.Vocab
|
|
||||||
err error
|
|
||||||
path string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewVocab(path string) *Vocab {
|
|
||||||
return &Vocab{path: path}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load returns the lazily-loaded vocabulary
|
|
||||||
func (v *Vocab) Load() (*llama.Vocab, error) {
|
|
||||||
v.once.Do(func() {
|
|
||||||
vocab, err := llama.LoadVocabFromFile(v.path)
|
|
||||||
if err != nil {
|
|
||||||
v.err = err
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
v.vocab = vocab
|
transforms = append(transforms, TopK(topK))
|
||||||
})
|
}
|
||||||
return v.vocab, v.err
|
|
||||||
|
if topP != 0 {
|
||||||
|
if topP < 0 || topP >= 1 {
|
||||||
|
return nil, errors.New("topP must be between 0 and 1")
|
||||||
|
}
|
||||||
|
transforms = append(transforms, TopP(topP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if minP != 0 {
|
||||||
|
if minP < 0 || minP >= 1 {
|
||||||
|
return nil, errors.New("minP must be between 0 and 1")
|
||||||
|
}
|
||||||
|
transforms = append(transforms, MinP(minP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if seed >= 0 {
|
||||||
|
seed64 := uint64(seed)
|
||||||
|
return Weighted(&seed64, transforms...), nil
|
||||||
|
}
|
||||||
|
return Weighted(nil, transforms...), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
package sample
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func BenchmarkWeightedSampler(b *testing.B) {
|
|
||||||
sizes := []int{10, 100, 1000, 10000}
|
|
||||||
|
|
||||||
for _, size := range sizes {
|
|
||||||
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
|
|
||||||
logits := make([]float32, size)
|
|
||||||
for i := range logits {
|
|
||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
|
||||||
}
|
|
||||||
|
|
||||||
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
sampler.Sample(logits)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
configs := []struct {
|
|
||||||
name string
|
|
||||||
temperature float32
|
|
||||||
topK int
|
|
||||||
topP float32
|
|
||||||
minP float32
|
|
||||||
seed int
|
|
||||||
}{
|
|
||||||
{"Greedy", 0, -1, 0, 0, -1},
|
|
||||||
{"Temperature", 0.8, -1, 0, 0, -1},
|
|
||||||
{"TopK", 0.8, 50, 0, 0, -1},
|
|
||||||
{"TopP", 0.8, -1, 0.9, 0, -1},
|
|
||||||
{"MinP", 0.8, -1, 0, 0.05, -1},
|
|
||||||
{"WithSeed", 0.8, 50, 0, 0, 42},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fixed size for common vocab size
|
|
||||||
size := 128000
|
|
||||||
logits := make([]float32, size)
|
|
||||||
for i := range logits {
|
|
||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range configs {
|
|
||||||
b.Run("Config"+tc.name, func(b *testing.B) {
|
|
||||||
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
|
|
||||||
sampler.Sample(logits)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for b.Loop() {
|
|
||||||
sampler.Sample(logits)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test with combined transforms separately - topK influences performance greatly
|
|
||||||
b.Run("TransformCombined", func(b *testing.B) {
|
|
||||||
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for b.Loop() {
|
|
||||||
sampler.Sample(logits)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkGreedySampler(b *testing.B) {
|
|
||||||
sizes := []int{10, 100, 1000, 10000, 100000}
|
|
||||||
|
|
||||||
for _, size := range sizes {
|
|
||||||
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
|
|
||||||
logits := make([]float32, size)
|
|
||||||
for i := range logits {
|
|
||||||
logits[i] = float32(rand.Float64()*10 - 5)
|
|
||||||
}
|
|
||||||
|
|
||||||
sampler := NewSampler(0, -1, 0, 0, -1, nil)
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for b.Loop() {
|
|
||||||
sampler.Sample(logits)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWeighted(t *testing.T) {
|
func TestWeighted(t *testing.T) {
|
||||||
logits := []float32{-10, 3, -10, -10}
|
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||||
sampler := NewSampler(0, 0, 0, 0, 0, nil)
|
|
||||||
got, err := sampler.Sample(logits)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
@@ -18,26 +19,194 @@ func TestWeighted(t *testing.T) {
|
|||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
logits = []float32{-100, -10, 0, 10}
|
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||||
sampler = NewSampler(0, 0, 0, 0, 0, nil)
|
if err == nil {
|
||||||
got, err = sampler.Sample(logits)
|
t.Error("expected error for no valid tokens, got index", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
seed := uint64(42)
|
||||||
|
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
want = int32(3) // Should pick highest probability with this r value
|
// With seed 42, we expect a consistent sample
|
||||||
|
want = int32(3) // This will be deterministic due to the seed
|
||||||
if want != got {
|
if want != got {
|
||||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkSample(b *testing.B) {
|
type testTransform struct {
|
||||||
samplers := map[string]Sampler{
|
id int
|
||||||
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
|
callOrder *[]int
|
||||||
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
|
}
|
||||||
|
|
||||||
|
func (ts *testTransform) Apply(logits []float64) []float64 {
|
||||||
|
if ts.callOrder != nil {
|
||||||
|
*ts.callOrder = append(*ts.callOrder, ts.id)
|
||||||
|
}
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSample(t *testing.T) {
|
||||||
|
input := []float32{1, 2, 3, 4}
|
||||||
|
|
||||||
|
var callOrder []int
|
||||||
|
mock1 := &testTransform{
|
||||||
|
id: 1,
|
||||||
|
callOrder: &callOrder,
|
||||||
|
}
|
||||||
|
mock2 := &testTransform{
|
||||||
|
id: 2,
|
||||||
|
callOrder: &callOrder,
|
||||||
|
}
|
||||||
|
mock3 := &testTransform{
|
||||||
|
id: 3,
|
||||||
|
callOrder: &callOrder,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wantOrder := []int{1, 2, 3}
|
||||||
|
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||||
|
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSampler(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
temperature float32
|
||||||
|
topK int
|
||||||
|
topP float32
|
||||||
|
minP float32
|
||||||
|
seed int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no transforms",
|
||||||
|
// temperature is 0, so greedy should be used
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "temperature",
|
||||||
|
temperature: 0.5,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid temperature negative",
|
||||||
|
temperature: -1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid temperature too high",
|
||||||
|
temperature: 2.1,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top k",
|
||||||
|
topK: 10,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid top k negative",
|
||||||
|
topK: -1,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "top p",
|
||||||
|
topP: 0.9,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid top p negative",
|
||||||
|
topP: -0.1,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid top p one",
|
||||||
|
topP: 1.0,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "min p",
|
||||||
|
minP: 0.2,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid min p negative",
|
||||||
|
minP: -0.1,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid min p one",
|
||||||
|
minP: 1.0,
|
||||||
|
temperature: 0.8,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default values",
|
||||||
|
temperature: 0.8,
|
||||||
|
topK: 40,
|
||||||
|
topP: 0.9,
|
||||||
|
minP: 0.0,
|
||||||
|
seed: 0,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all zeroes",
|
||||||
|
temperature: 0.0,
|
||||||
|
topK: 0,
|
||||||
|
topP: 0.0,
|
||||||
|
minP: 0.0,
|
||||||
|
seed: 0,
|
||||||
|
wantErr: false, // all zeroes means no transforms
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all transforms",
|
||||||
|
temperature: 0.8,
|
||||||
|
topK: 50,
|
||||||
|
topP: 0.95,
|
||||||
|
minP: 0.1,
|
||||||
|
seed: 42,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSample(b *testing.B) {
|
||||||
|
transforms := []Transform{
|
||||||
|
Temperature(0.5),
|
||||||
|
TopK(10),
|
||||||
|
TopP(0.9),
|
||||||
|
MinP(0.2),
|
||||||
|
}
|
||||||
|
|
||||||
|
samplers := map[string]Sampler{
|
||||||
|
"Greedy": Greedy(),
|
||||||
|
"Weighted": Weighted(nil, transforms...),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate random logits for benchmarking
|
|
||||||
logits := make([]float32, 1<<16)
|
logits := make([]float32, 1<<16)
|
||||||
for i := range logits {
|
for i := range logits {
|
||||||
logits[i] = rand.Float32()
|
logits[i] = rand.Float32()
|
||||||
@@ -46,9 +215,9 @@ func BenchmarkSample(b *testing.B) {
|
|||||||
for name, s := range samplers {
|
for name, s := range samplers {
|
||||||
b.Run(name, func(b *testing.B) {
|
b.Run(name, func(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for b.Loop() {
|
for range b.N {
|
||||||
if _, err := s.Sample(logits); err != nil {
|
if _, err := s.Sample(logits); err != nil {
|
||||||
b.Fatalf("error sampling: %v", err)
|
b.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,145 +1,120 @@
|
|||||||
package sample
|
package sample
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"container/heap"
|
"cmp"
|
||||||
"math"
|
"math"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
|
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||||
)
|
)
|
||||||
|
|
||||||
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
|
type Transform interface {
|
||||||
type tokenHeap []token
|
Apply([]float64) []float64
|
||||||
|
|
||||||
func (h tokenHeap) Len() int { return len(h) }
|
|
||||||
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
|
|
||||||
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
|
||||||
|
|
||||||
func (h *tokenHeap) Push(x any) {
|
|
||||||
*h = append(*h, x.(token))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *tokenHeap) Pop() any {
|
// TODO(parthsareen): potentially cache softmax values
|
||||||
old := *h
|
func softmax(logits []float64) []float64 {
|
||||||
n := len(old)
|
var sum float64
|
||||||
x := old[n-1]
|
probs := make([]float64, len(logits))
|
||||||
*h = old[0 : n-1]
|
for i, v := range logits {
|
||||||
return x
|
probs[i] = math.Exp(v)
|
||||||
}
|
sum += probs[i]
|
||||||
|
|
||||||
// temperature applies scaling to the logits
|
|
||||||
func temperature(ts []token, temp float32) []token {
|
|
||||||
// Ensure temperature clipping near 0 to avoid numerical instability
|
|
||||||
temp = max(temp, 1e-7)
|
|
||||||
for i := range ts {
|
|
||||||
ts[i].value = ts[i].value / temp
|
|
||||||
}
|
}
|
||||||
return ts
|
|
||||||
|
for i := range probs {
|
||||||
|
probs[i] /= sum
|
||||||
|
}
|
||||||
|
|
||||||
|
return probs
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax applies normalization to the logits
|
type Temperature float64
|
||||||
func softmax(ts []token) []token {
|
|
||||||
// Find max logit for numerical stability
|
func (t Temperature) Apply(logits []float64) []float64 {
|
||||||
maxLogit := float32(math.Inf(-1))
|
temp := math.Max(float64(t), 1e-7)
|
||||||
for _, t := range ts {
|
|
||||||
if t.value > maxLogit {
|
// subtracting max logit to avoid under/overflow
|
||||||
maxLogit = t.value
|
maxLogit := slices.Max(logits)
|
||||||
|
for i := range logits {
|
||||||
|
logits[i] = (logits[i] - maxLogit) / temp
|
||||||
|
}
|
||||||
|
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
|
||||||
|
type logitMap struct {
|
||||||
|
index int
|
||||||
|
logit float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type TopK int
|
||||||
|
|
||||||
|
// TODO(parthsareen): avoid having to check all logits after this transform
|
||||||
|
func (k TopK) Apply(logits []float64) []float64 {
|
||||||
|
if int(k) >= len(logits) {
|
||||||
|
return logits
|
||||||
|
}
|
||||||
|
q := pq.NewWith(func(a, b logitMap) int {
|
||||||
|
return -cmp.Compare(a.logit, b.logit)
|
||||||
|
})
|
||||||
|
|
||||||
|
for i, logit := range logits {
|
||||||
|
q.Enqueue(logitMap{index: i, logit: logit})
|
||||||
|
}
|
||||||
|
|
||||||
|
validLogits := make(map[int]float64)
|
||||||
|
for range k {
|
||||||
|
logitMap, _ := q.Dequeue()
|
||||||
|
validLogits[logitMap.index] = logitMap.logit
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range logits {
|
||||||
|
if _, ok := validLogits[i]; !ok {
|
||||||
|
logits[i] = math.Inf(-1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute exp(x - max)
|
return logits
|
||||||
var sum float32
|
|
||||||
for i, v := range ts {
|
|
||||||
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
|
|
||||||
sum += ts[i].value
|
|
||||||
}
|
|
||||||
|
|
||||||
// exp(x - max) / sum(exp(x - max))
|
|
||||||
for i := range ts {
|
|
||||||
ts[i].value /= sum
|
|
||||||
}
|
|
||||||
|
|
||||||
return ts
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// topK limits the number of tokens considered to the k highest logits
|
type TopP float64
|
||||||
func topK(ts []token, k int) []token {
|
|
||||||
if k >= len(ts) || k <= 0 {
|
func (p TopP) Apply(logits []float64) []float64 {
|
||||||
slices.SortFunc(ts, func(a, b token) int {
|
probs := softmax(logits)
|
||||||
switch {
|
indices := make([]int, len(probs))
|
||||||
case a.value < b.value:
|
for i := range indices {
|
||||||
return 1
|
indices[i] = i
|
||||||
case a.value > b.value:
|
}
|
||||||
return -1
|
|
||||||
default:
|
// sort in descending order
|
||||||
return 0
|
slices.SortFunc(indices, func(i, j int) int {
|
||||||
|
return cmp.Compare(probs[j], probs[i])
|
||||||
|
})
|
||||||
|
|
||||||
|
var sum float64
|
||||||
|
for i, idx := range indices {
|
||||||
|
sum += probs[idx]
|
||||||
|
if sum > float64(p) {
|
||||||
|
for _, idx := range indices[i+1:] {
|
||||||
|
logits[idx] = math.Inf(-1)
|
||||||
}
|
}
|
||||||
})
|
break
|
||||||
return ts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize min-heap with first k elements
|
|
||||||
h := make(tokenHeap, k)
|
|
||||||
copy(h, ts[:k])
|
|
||||||
heap.Init(&h)
|
|
||||||
|
|
||||||
// Process remaining elements
|
|
||||||
for i := k; i < len(ts); i++ {
|
|
||||||
if ts[i].value > h[0].value {
|
|
||||||
heap.Pop(&h)
|
|
||||||
heap.Push(&h, ts[i])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return logits
|
||||||
// Convert heap to sorted slice in descending order
|
|
||||||
result := make([]token, len(h))
|
|
||||||
for i := k - 1; i >= 0; i-- {
|
|
||||||
result[i] = heap.Pop(&h).(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// topP limits tokens to those with cumulative probability p
|
type MinP float64
|
||||||
func topP(ts []token, p float32) []token {
|
|
||||||
if p == 1.0 {
|
|
||||||
return ts
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find cutoff index where cumulative sum exceeds p
|
func (p MinP) Apply(logits []float64) []float64 {
|
||||||
var sum float32
|
probs := softmax(logits)
|
||||||
for i, t := range ts {
|
threshold := slices.Max(probs) * float64(p)
|
||||||
sum += t.value
|
|
||||||
if sum > float32(p) {
|
for i, prob := range probs {
|
||||||
ts = ts[:i+1]
|
if prob < threshold {
|
||||||
return ts
|
logits[i] = math.Inf(-1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ts
|
return logits
|
||||||
}
|
|
||||||
|
|
||||||
// minP limits tokens to those with cumulative probability p
|
|
||||||
func minP(ts []token, p float32) []token {
|
|
||||||
if p == 1.0 {
|
|
||||||
return ts
|
|
||||||
}
|
|
||||||
|
|
||||||
maxProb := float32(math.Inf(-1))
|
|
||||||
for _, token := range ts {
|
|
||||||
if token.value > maxProb {
|
|
||||||
maxProb = token.value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threshold := maxProb * float32(p)
|
|
||||||
|
|
||||||
// Filter tokens in-place
|
|
||||||
validTokens := ts[:0]
|
|
||||||
for i, token := range ts {
|
|
||||||
if token.value >= threshold {
|
|
||||||
validTokens = append(validTokens, ts[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ts = validTokens
|
|
||||||
return ts
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,258 +4,77 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Helper to convert float32 slice to logit slice
|
|
||||||
func toTokens(values []float32) []token {
|
|
||||||
tokens := make([]token, len(values))
|
|
||||||
for i, v := range values {
|
|
||||||
tokens[i] = token{
|
|
||||||
id: int32(i),
|
|
||||||
value: v,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper to compare logit slices
|
|
||||||
func compareLogits(t *testing.T, name string, want []float32, got []token) {
|
|
||||||
t.Helper()
|
|
||||||
if len(want) != len(got) {
|
|
||||||
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for i := range want {
|
|
||||||
if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
|
|
||||||
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTemperature(t *testing.T) {
|
func TestTemperature(t *testing.T) {
|
||||||
input := []float32{1.0, 4.0, -2.0, 0.0}
|
got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
|
||||||
got := temperature(toTokens(input), 0.5)
|
want := []float64{-4, -10, 0, -14, -6, -12, -8}
|
||||||
want := []float32{2.0, 8.0, -4.0, 0.0}
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
compareLogits(t, "temperature(0.5)", want, got)
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
got = temperature(toTokens(input), 1.0)
|
|
||||||
want = []float32{1.0, 4.0, -2.0, 0.0}
|
|
||||||
compareLogits(t, "temperature(1)", want, got)
|
|
||||||
|
|
||||||
got = temperature(toTokens(input), 0.0)
|
|
||||||
want = []float32{1e7, 4e7, -2e7, 0.0}
|
|
||||||
compareLogits(t, "temperature(0)", want, got)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSoftmax(t *testing.T) {
|
func TestSoftmax(t *testing.T) {
|
||||||
tests := []struct {
|
got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
name string
|
|
||||||
input []float32
|
|
||||||
expected []float32
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "correctness softmax",
|
|
||||||
input: []float32{1, -2, 3, 0},
|
|
||||||
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "normal distribution",
|
|
||||||
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single value",
|
|
||||||
input: []float32{1.0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "identical values",
|
|
||||||
input: []float32{0.9, 0.9, 0.9},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large values",
|
|
||||||
input: []float32{1000.0, 2000.0, 3000.0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "small values",
|
|
||||||
input: []float32{1e-6, 2e-6, 3e-6},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative values",
|
|
||||||
input: []float32{-1.0, -2.0, -3.0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed values",
|
|
||||||
input: []float32{-100.0, 0.0, 100.0},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
got := softmax(toTokens(tt.input))
|
t.Errorf("probs mismatch (-want +got):\n%s", diff)
|
||||||
|
|
||||||
if tt.expected != nil {
|
|
||||||
compareLogits(t, tt.name, tt.expected, got)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check probabilities sum to 1
|
|
||||||
var sum float32
|
|
||||||
for _, token := range got {
|
|
||||||
sum += token.value
|
|
||||||
if token.value < 0 || token.value > 1 {
|
|
||||||
t.Errorf("probability out of range [0,1]: got %f", token.value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if math.Abs(float64(sum-1.0)) > 1e-6 {
|
|
||||||
t.Errorf("probabilities don't sum to 1: got %f", sum)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTopK(t *testing.T) {
|
func TestTopK(t *testing.T) {
|
||||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
|
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
|
||||||
// Test k=5
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
got := topK(toTokens(input), 5)
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
if len(got) != 5 {
|
|
||||||
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
|
|
||||||
}
|
|
||||||
// Should keep highest 3 values in descending order
|
|
||||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
|
|
||||||
compareLogits(t, "topK(3)", want, got)
|
|
||||||
|
|
||||||
got = topK(toTokens(input), 20)
|
|
||||||
if len(got) != len(input) {
|
|
||||||
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test k=-1
|
got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
|
||||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
|
||||||
got = topK(toTokens(input), -1)
|
|
||||||
if len(got) != len(input) {
|
|
||||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
|
||||||
}
|
|
||||||
compareLogits(t, "topK(-1)", want, got)
|
|
||||||
|
|
||||||
// Test k=0
|
want = []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||||
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
got = topK(toTokens(input), 0)
|
|
||||||
if len(got) != len(input) {
|
|
||||||
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
|
|
||||||
}
|
}
|
||||||
compareLogits(t, "topK(-1)", want, got)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTopP(t *testing.T) {
|
func TestTopP(t *testing.T) {
|
||||||
input := []float32{-3, -2, -1, 0, 1, 2, 4}
|
got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||||
tokens := toTokens(input)
|
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
// First apply temperature and softmax to get probabilities
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
tokens = softmax(tokens)
|
|
||||||
tokens = topK(tokens, 20)
|
|
||||||
|
|
||||||
// Then apply topP
|
|
||||||
got := topP(tokens, 0.95)
|
|
||||||
|
|
||||||
// Should keep tokens until cumsum > 0.95
|
|
||||||
if len(got) > 3 {
|
|
||||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
|
||||||
t.Logf("got: %v", got)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMinP(t *testing.T) {
|
func TestMinP(t *testing.T) {
|
||||||
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
|
got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
|
||||||
tokens := toTokens(input)
|
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
|
||||||
|
if diff := cmp.Diff(want, got); diff != "" {
|
||||||
// First apply temperature and softmax
|
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||||
tokens = softmax(tokens)
|
|
||||||
|
|
||||||
// Then apply minP
|
|
||||||
got := minP(tokens, 0.2)
|
|
||||||
|
|
||||||
// Should keep tokens with prob >= 0.2 * max_prob
|
|
||||||
if len(got) > 3 {
|
|
||||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSortLogits(t *testing.T) {
|
func BenchmarkTransform(b *testing.B) {
|
||||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
transforms := map[string]Transform{
|
||||||
tokens := toTokens(input)
|
"Temperature": Temperature(0.5),
|
||||||
|
"TopK": TopK(10),
|
||||||
tokens = topK(tokens, 20)
|
"TopP": TopP(0.9),
|
||||||
|
"MinP": MinP(0.2),
|
||||||
for i := 1; i < len(tokens); i++ {
|
|
||||||
if tokens[i].value > tokens[i-1].value {
|
|
||||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
|
||||||
i, tokens[i].value, tokens[i-1].value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
logits := make([]float64, 1<<16)
|
||||||
compareLogits(t, "sortLogits", want, tokens)
|
for i := range logits {
|
||||||
}
|
logits[i] = rand.Float64()
|
||||||
|
|
||||||
func BenchmarkTransforms(b *testing.B) {
|
|
||||||
// Generate random logits
|
|
||||||
tokens := make([]token, 1<<16)
|
|
||||||
for i := range tokens {
|
|
||||||
tokens[i] = token{
|
|
||||||
id: int32(i),
|
|
||||||
value: rand.Float32(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokensCopy := make([]token, len(tokens))
|
for name, transform := range transforms {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
b.Run("Temperature", func(b *testing.B) {
|
b.ResetTimer()
|
||||||
b.ResetTimer()
|
for range b.N {
|
||||||
for b.Loop() {
|
transform.Apply(logits)
|
||||||
copy(tokensCopy, tokens)
|
}
|
||||||
temperature(tokensCopy, 0.5)
|
})
|
||||||
}
|
}
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("Softmax", func(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
copy(tokensCopy, tokens)
|
|
||||||
softmax(tokensCopy)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("TopK", func(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
copy(tokensCopy, tokens)
|
|
||||||
topK(tokensCopy, 10)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("TopP", func(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
copy(tokensCopy, tokens)
|
|
||||||
topP(tokensCopy, 0.9)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("MinP", func(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
copy(tokensCopy, tokens)
|
|
||||||
minP(tokensCopy, 0.2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("SortTokens", func(b *testing.B) {
|
|
||||||
b.ResetTimer()
|
|
||||||
for b.Loop() {
|
|
||||||
copy(tokensCopy, tokens)
|
|
||||||
topK(tokensCopy, 200000)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,14 +80,13 @@ function checkEnv() {
|
|||||||
|
|
||||||
|
|
||||||
function buildOllama() {
|
function buildOllama() {
|
||||||
mkdir -Force -path "${script:DIST_DIR}\"
|
|
||||||
if ($script:ARCH -ne "arm64") {
|
if ($script:ARCH -ne "arm64") {
|
||||||
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
|
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
|
||||||
New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
|
New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
|
||||||
|
|
||||||
& cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
|
& cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --build --preset CPU --config Release --parallel $script:JOBS
|
& cmake --build --preset CPU --parallel $script:JOBS
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build --component CPU --strip
|
& cmake --install build --component CPU --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
@@ -102,7 +101,7 @@ function buildOllama() {
|
|||||||
# to avoid 2022 (or newer) from being used as the default
|
# to avoid 2022 (or newer) from being used as the default
|
||||||
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
|
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
|
& cmake --build --preset "CUDA 11" --parallel $script:JOBS
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build --component "CUDA" --strip
|
& cmake --install build --component "CUDA" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
@@ -113,7 +112,7 @@ function buildOllama() {
|
|||||||
write-host "Building CUDA v12 backend libraries"
|
write-host "Building CUDA v12 backend libraries"
|
||||||
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
|
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
|
& cmake --build --preset "CUDA 12" --parallel $script:JOBS
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build --component "CUDA" --strip
|
& cmake --install build --component "CUDA" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
@@ -132,7 +131,7 @@ function buildOllama() {
|
|||||||
$env:HIPCXX=""
|
$env:HIPCXX=""
|
||||||
$env:HIP_PLATFORM=""
|
$env:HIP_PLATFORM=""
|
||||||
$env:CMAKE_PREFIX_PATH=""
|
$env:CMAKE_PREFIX_PATH=""
|
||||||
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
|
& cmake --build --preset "ROCm" --parallel $script:JOBS
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build --component "HIP" --strip
|
& cmake --install build --component "HIP" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
|||||||
@@ -77,12 +77,11 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then
|
|||||||
fi
|
fi
|
||||||
status "Installing ollama to $OLLAMA_INSTALL_DIR"
|
status "Installing ollama to $OLLAMA_INSTALL_DIR"
|
||||||
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
$SUDO install -o0 -g0 -m755 -d $BINDIR
|
||||||
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
|
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR"
|
||||||
status "Downloading Linux ${ARCH} bundle"
|
status "Downloading Linux ${ARCH} bundle"
|
||||||
curl --fail --show-error --location --progress-bar \
|
curl --fail --show-error --location --progress-bar \
|
||||||
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
|
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
|
||||||
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
|
||||||
|
|
||||||
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
|
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
|
||||||
status "Making ollama accessible in the PATH in $BINDIR"
|
status "Making ollama accessible in the PATH in $BINDIR"
|
||||||
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"
|
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -35,7 +34,6 @@ var (
|
|||||||
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
|
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
|
||||||
errUnknownType = errors.New("unknown type")
|
errUnknownType = errors.New("unknown type")
|
||||||
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
|
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
|
||||||
errFilePath = errors.New("file path must be relative")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) CreateHandler(c *gin.Context) {
|
func (s *Server) CreateHandler(c *gin.Context) {
|
||||||
@@ -48,13 +46,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for v := range r.Files {
|
|
||||||
if !fs.ValidPath(v) {
|
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||||
if !name.IsValid() {
|
if !name.IsValid() {
|
||||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||||
@@ -113,7 +104,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||||||
if r.Adapters != nil {
|
if r.Adapters != nil {
|
||||||
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
|
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
|
||||||
if errors.Is(err, badReq) {
|
if errors.Is(err, badReq) {
|
||||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||||
return
|
return
|
||||||
@@ -230,22 +221,8 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer os.RemoveAll(tmpDir)
|
defer os.RemoveAll(tmpDir)
|
||||||
// Set up a root to validate paths
|
|
||||||
root, err := os.OpenRoot(tmpDir)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer root.Close()
|
|
||||||
|
|
||||||
for fp, digest := range files {
|
for fp, digest := range files {
|
||||||
if !fs.ValidPath(fp) {
|
|
||||||
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
|
|
||||||
}
|
|
||||||
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
// Path is likely outside the root
|
|
||||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
|
||||||
}
|
|
||||||
|
|
||||||
blobPath, err := GetBlobsPath(digest)
|
blobPath, err := GetBlobsPath(digest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -293,7 +270,6 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer bin.Close()
|
|
||||||
|
|
||||||
f, _, err := ggml.Decode(bin, 0)
|
f, _, err := ggml.Decode(bin, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,106 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConvertFromSafetensors(t *testing.T) {
|
|
||||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
|
||||||
|
|
||||||
// Helper function to create a new layer and return its digest
|
|
||||||
makeTemp := func(content string) string {
|
|
||||||
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create layer: %v", err)
|
|
||||||
}
|
|
||||||
return l.Digest
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a safetensors compatible file with empty JSON content
|
|
||||||
var buf bytes.Buffer
|
|
||||||
headerSize := int64(len("{}"))
|
|
||||||
binary.Write(&buf, binary.LittleEndian, headerSize)
|
|
||||||
buf.WriteString("{}")
|
|
||||||
|
|
||||||
model := makeTemp(buf.String())
|
|
||||||
config := makeTemp(`{
|
|
||||||
"architectures": ["LlamaForCausalLM"],
|
|
||||||
"vocab_size": 32000
|
|
||||||
}`)
|
|
||||||
tokenizer := makeTemp(`{
|
|
||||||
"version": "1.0",
|
|
||||||
"truncation": null,
|
|
||||||
"padding": null,
|
|
||||||
"added_tokens": [
|
|
||||||
{
|
|
||||||
"id": 0,
|
|
||||||
"content": "<|endoftext|>",
|
|
||||||
"single_word": false,
|
|
||||||
"lstrip": false,
|
|
||||||
"rstrip": false,
|
|
||||||
"normalized": false,
|
|
||||||
"special": true
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
filePath string
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
// Invalid
|
|
||||||
{
|
|
||||||
name: "InvalidRelativePathShallow",
|
|
||||||
filePath: filepath.Join("..", "file.safetensors"),
|
|
||||||
wantErr: errFilePath,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "InvalidRelativePathDeep",
|
|
||||||
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
|
|
||||||
wantErr: errFilePath,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "InvalidNestedPath",
|
|
||||||
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
|
|
||||||
wantErr: errFilePath,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "AbsolutePathOutsideRoot",
|
|
||||||
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
|
|
||||||
wantErr: errFilePath, // Should fail since it's outside tmpDir
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ValidRelativePath",
|
|
||||||
filePath: "model.safetensors",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Create the minimum required file map for convertFromSafetensors
|
|
||||||
files := map[string]string{
|
|
||||||
tt.filePath: model,
|
|
||||||
"config.json": config,
|
|
||||||
"tokenizer.json": tokenizer,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
|
|
||||||
|
|
||||||
if (tt.wantErr == nil && err != nil) ||
|
|
||||||
(tt.wantErr != nil && err == nil) ||
|
|
||||||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
|
|
||||||
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -27,7 +27,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -45,9 +44,9 @@ import (
|
|||||||
|
|
||||||
// Errors
|
// Errors
|
||||||
var (
|
var (
|
||||||
// ErrModelNotFound is returned when a manifest is not found in the
|
// ErrManifestNotFound is returned when a manifest is not found in the
|
||||||
// cache or registry.
|
// cache or registry.
|
||||||
ErrModelNotFound = errors.New("model not found")
|
ErrManifestNotFound = errors.New("manifest not found")
|
||||||
|
|
||||||
// ErrManifestInvalid is returned when a manifest found in a local or
|
// ErrManifestInvalid is returned when a manifest found in a local or
|
||||||
// remote cache is invalid.
|
// remote cache is invalid.
|
||||||
@@ -74,22 +73,19 @@ const (
|
|||||||
DefaultMaxChunkSize = 8 << 20
|
DefaultMaxChunkSize = 8 << 20
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
// DefaultCache returns a new disk cache for storing models. If the
|
||||||
|
// OLLAMA_MODELS environment variable is set, it uses that directory;
|
||||||
|
// otherwise, it uses $HOME/.ollama/models.
|
||||||
|
func DefaultCache() (*blob.DiskCache, error) {
|
||||||
dir := os.Getenv("OLLAMA_MODELS")
|
dir := os.Getenv("OLLAMA_MODELS")
|
||||||
if dir == "" {
|
if dir == "" {
|
||||||
home, _ := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
home = cmp.Or(home, ".")
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
dir = filepath.Join(home, ".ollama", "models")
|
dir = filepath.Join(home, ".ollama", "models")
|
||||||
}
|
}
|
||||||
return blob.Open(dir)
|
return blob.Open(dir)
|
||||||
})
|
|
||||||
|
|
||||||
// DefaultCache returns the default cache used by the registry. It is
|
|
||||||
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
|
||||||
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
|
||||||
// it uses the current working directory.
|
|
||||||
func DefaultCache() (*blob.DiskCache, error) {
|
|
||||||
return defaultCache()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||||
@@ -114,18 +110,7 @@ type Error struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Error) Error() string {
|
func (e *Error) Error() string {
|
||||||
var b strings.Builder
|
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
|
||||||
b.WriteString("registry responded with status ")
|
|
||||||
b.WriteString(strconv.Itoa(e.Status))
|
|
||||||
if e.Code != "" {
|
|
||||||
b.WriteString(": code ")
|
|
||||||
b.WriteString(e.Code)
|
|
||||||
}
|
|
||||||
if e.Message != "" {
|
|
||||||
b.WriteString(": ")
|
|
||||||
b.WriteString(e.Message)
|
|
||||||
}
|
|
||||||
return b.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Error) LogValue() slog.Value {
|
func (e *Error) LogValue() slog.Value {
|
||||||
@@ -183,10 +168,6 @@ func CompleteName(name string) string {
|
|||||||
// Registry is a client for performing push and pull operations against an
|
// Registry is a client for performing push and pull operations against an
|
||||||
// Ollama registry.
|
// Ollama registry.
|
||||||
type Registry struct {
|
type Registry struct {
|
||||||
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
|
||||||
// used.
|
|
||||||
Cache *blob.DiskCache
|
|
||||||
|
|
||||||
// UserAgent is the User-Agent header to send with requests to the
|
// UserAgent is the User-Agent header to send with requests to the
|
||||||
// registry. If empty, the User-Agent is determined by HTTPClient.
|
// registry. If empty, the User-Agent is determined by HTTPClient.
|
||||||
UserAgent string
|
UserAgent string
|
||||||
@@ -225,28 +206,18 @@ type Registry struct {
|
|||||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||||
MaxChunkSize int64
|
MaxChunkSize int64
|
||||||
|
|
||||||
// Mask, if set, is the name used to convert non-fully qualified names
|
// Mask, if set, is the name used to convert non-fully qualified
|
||||||
// to fully qualified names. If empty, [DefaultMask] is used.
|
// names to fully qualified names. If empty, the default mask
|
||||||
|
// ("registry.ollama.ai/library/_:latest") is used.
|
||||||
Mask string
|
Mask string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Registry) cache() (*blob.DiskCache, error) {
|
func (r *Registry) completeName(name string) names.Name {
|
||||||
if r.Cache != nil {
|
|
||||||
return r.Cache, nil
|
|
||||||
}
|
|
||||||
return defaultCache()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
|
||||||
mask := defaultMask
|
mask := defaultMask
|
||||||
if r.Mask != "" {
|
if r.Mask != "" {
|
||||||
mask = names.Parse(r.Mask)
|
mask = names.Parse(r.Mask)
|
||||||
}
|
}
|
||||||
n := names.Merge(names.Parse(name), mask)
|
return names.Merge(names.Parse(name), mask)
|
||||||
if !n.IsFullyQualified() {
|
|
||||||
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||||
@@ -307,17 +278,12 @@ type PushParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Push pushes the model with the name in the cache to the remote registry.
|
// Push pushes the model with the name in the cache to the remote registry.
|
||||||
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
p = &PushParams{}
|
p = &PushParams{}
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := r.cache()
|
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
m, err := r.ResolveLocal(cmp.Or(p.From, name))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -340,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
|||||||
|
|
||||||
t := traceFromContext(ctx)
|
t := traceFromContext(ctx)
|
||||||
|
|
||||||
scheme, n, _, err := r.parseNameExtended(name)
|
scheme, n, _, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// This should never happen since ResolveLocal should have
|
// This should never happen since ResolveLocal should have
|
||||||
// already validated the name.
|
// already validated the name.
|
||||||
@@ -366,7 +332,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
|||||||
n.Model(),
|
n.Model(),
|
||||||
l.Digest,
|
l.Digest,
|
||||||
)
|
)
|
||||||
res, err := r.send(ctx, "POST", startURL, nil)
|
res, err := r.doOK(ctx, "POST", startURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -390,7 +356,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
|||||||
}
|
}
|
||||||
req.ContentLength = l.Size
|
req.ContentLength = l.Size
|
||||||
|
|
||||||
res, err = sendRequest(r.client(), req)
|
res, err = doOK(r.client(), req)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
}
|
}
|
||||||
@@ -410,7 +376,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
|||||||
n.Model(),
|
n.Model(),
|
||||||
n.Tag(),
|
n.Tag(),
|
||||||
)
|
)
|
||||||
res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
|
res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
}
|
}
|
||||||
@@ -433,8 +399,8 @@ func canRetry(err error) bool {
|
|||||||
// chunks of the specified size, and then reassembled and verified. This is
|
// chunks of the specified size, and then reassembled and verified. This is
|
||||||
// typically slower than splitting the model up across layers, and is mostly
|
// typically slower than splitting the model up across layers, and is mostly
|
||||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||||
func (r *Registry) Pull(ctx context.Context, name string) error {
|
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||||
scheme, n, _, err := r.parseNameExtended(name)
|
scheme, n, _, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -447,11 +413,6 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := r.cache()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
exists := func(l *Layer) bool {
|
exists := func(l *Layer) bool {
|
||||||
info, err := c.Get(l.Digest)
|
info, err := c.Get(l.Digest)
|
||||||
return err == nil && info.Size == l.Size
|
return err == nil && info.Size == l.Size
|
||||||
@@ -459,15 +420,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
|
|
||||||
t := traceFromContext(ctx)
|
t := traceFromContext(ctx)
|
||||||
|
|
||||||
g, ctx := errgroup.WithContext(ctx)
|
var g errgroup.Group
|
||||||
g.SetLimit(r.maxStreams())
|
g.SetLimit(r.maxStreams())
|
||||||
|
|
||||||
layers := m.Layers
|
for _, l := range m.Layers {
|
||||||
if m.Config != nil && m.Config.Digest.IsValid() {
|
|
||||||
layers = append(layers, m.Config)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, l := range layers {
|
|
||||||
if exists(l) {
|
if exists(l) {
|
||||||
t.update(l, l.Size, ErrCached)
|
t.update(l, l.Size, ErrCached)
|
||||||
continue
|
continue
|
||||||
@@ -484,9 +440,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
|
|
||||||
if l.Size <= r.maxChunkingThreshold() {
|
if l.Size <= r.maxChunkingThreshold() {
|
||||||
g.Go(func() error {
|
g.Go(func() error {
|
||||||
// TODO(bmizerany): retry/backoff like below in
|
res, err := doOK(r.client(), req)
|
||||||
// the chunking case
|
|
||||||
res, err := sendRequest(r.client(), req)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -512,21 +466,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
// fire an initial request to get the final URL and
|
// fire an initial request to get the final URL and
|
||||||
// then use that URL for the chunk requests.
|
// then use that URL for the chunk requests.
|
||||||
req.Header.Set("Range", "bytes=0-0")
|
req.Header.Set("Range", "bytes=0-0")
|
||||||
res, err := sendRequest(r.client(), req)
|
res, err := doOK(r.client(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
req = res.Request.WithContext(req.Context())
|
req = res.Request.WithContext(req.Context())
|
||||||
|
|
||||||
wp := writerPool{size: r.maxChunkSize()}
|
streamNo := 0
|
||||||
|
tws := make([]*bufio.Writer, r.maxStreams()-1)
|
||||||
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
|
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
|
||||||
if ctx.Err() != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
ticket := q.Take()
|
ticket := q.Take()
|
||||||
|
bufIdx := streamNo % len(tws)
|
||||||
|
streamNo++
|
||||||
g.Go(func() (err error) {
|
g.Go(func() (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -540,18 +492,23 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := func() error {
|
err := func() error {
|
||||||
req := req.Clone(req.Context())
|
req := req.Clone(req.Context())
|
||||||
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
|
||||||
res, err := sendRequest(r.client(), req)
|
res, err := doOK(r.client(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
tw := wp.get()
|
tw := tws[bufIdx]
|
||||||
|
if tw == nil {
|
||||||
|
tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
|
||||||
|
tws[bufIdx] = tw
|
||||||
|
}
|
||||||
tw.Reset(ticket)
|
tw.Reset(ticket)
|
||||||
defer wp.put(tw)
|
defer tw.Reset(nil) // release ticket
|
||||||
|
|
||||||
_, err = io.CopyN(tw, res.Body, chunk.Size())
|
_, err = io.CopyN(tw, res.Body, chunk.Size())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -593,14 +550,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
|||||||
|
|
||||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||||
// before attempting to unlink the model.
|
// before attempting to unlink the model.
|
||||||
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||||
n, err := r.parseName(name)
|
n := r.completeName(name)
|
||||||
if err != nil {
|
if !n.IsFullyQualified() {
|
||||||
return false, err
|
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
||||||
}
|
|
||||||
c, err := r.cache()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
}
|
||||||
return c.Unlink(n.String())
|
return c.Unlink(n.String())
|
||||||
}
|
}
|
||||||
@@ -610,9 +563,6 @@ type Manifest struct {
|
|||||||
Name string `json:"-"` // the canonical name of the model
|
Name string `json:"-"` // the canonical name of the model
|
||||||
Data []byte `json:"-"` // the raw data of the manifest
|
Data []byte `json:"-"` // the raw data of the manifest
|
||||||
Layers []*Layer `json:"layers"`
|
Layers []*Layer `json:"layers"`
|
||||||
|
|
||||||
// For legacy reasons, we still have to download the config layer.
|
|
||||||
Config *Layer `json:"config"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
|
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
|
||||||
@@ -676,18 +626,14 @@ type Layer struct {
|
|||||||
Size int64 `json:"size"`
|
Size int64 `json:"size"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
|
||||||
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
// parsed using [names.Split] but the scheme is ignored.
|
||||||
_, n, d, err := r.parseNameExtended(name)
|
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||||
if err != nil {
|
_, n, d, err := parseName(name, r.Mask)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c, err := r.cache()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !d.IsValid() {
|
if !d.IsValid() {
|
||||||
// No digest, so resolve the manifest by name.
|
|
||||||
d, err = c.Resolve(n.String())
|
d, err = c.Resolve(n.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -696,7 +642,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
|||||||
data, err := os.ReadFile(c.GetFile(d))
|
data, err := os.ReadFile(c.GetFile(d))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
|
return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -709,7 +655,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
|||||||
|
|
||||||
// Resolve resolves a name to a Manifest in the remote registry.
|
// Resolve resolves a name to a Manifest in the remote registry.
|
||||||
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
||||||
scheme, n, d, err := r.parseNameExtended(name)
|
scheme, n, d, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -719,7 +665,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
|
|||||||
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
|
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := r.send(ctx, "GET", manifestURL, nil)
|
res, err := r.doOK(ctx, "GET", manifestURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -744,7 +690,7 @@ func (r *Registry) client() *http.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newRequest constructs a new request, ready to use, with the given method,
|
// newRequest constructs a new request, ready to use, with the given method,
|
||||||
// url, and body, pre-signed with client [Key] and [UserAgent].
|
// url, and body, presigned with client Key and UserAgent.
|
||||||
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -763,17 +709,11 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendRequest makes a request with the given client and request, and returns the
|
// doOK makes a request with the given client and request, and returns the
|
||||||
// response if the status code is 200. If the status code is not 200, an Error
|
// response if the status code is 200. If the status code is not 200, an Error
|
||||||
// is parsed from the response body and returned. If any other error occurs, it
|
// is parsed from the response body and returned. If any other error occurs, it
|
||||||
// is returned.
|
// is returned.
|
||||||
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
|
func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("request error %s: %w", r.URL, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if r.URL.Scheme == "https+insecure" {
|
if r.URL.Scheme == "https+insecure" {
|
||||||
// TODO(bmizerany): clone client.Transport, set
|
// TODO(bmizerany): clone client.Transport, set
|
||||||
// InsecureSkipVerify, etc.
|
// InsecureSkipVerify, etc.
|
||||||
@@ -816,26 +756,20 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
|
|||||||
// Use the raw body if we can't parse it as an error object.
|
// Use the raw body if we can't parse it as an error object.
|
||||||
re.Message = string(out)
|
re.Message = string(out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// coerce MANIFEST_UNKNOWN to ErrManifestNotFound
|
|
||||||
if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
|
|
||||||
return nil, ErrModelNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
re.Status = res.StatusCode
|
re.Status = res.StatusCode
|
||||||
return nil, &re
|
return nil, &re
|
||||||
}
|
}
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// send is a convenience method for making a request with newRequest and
|
// doOK is a convenience method for making a request with newRequest and
|
||||||
// passing it to send with r.client().
|
// passing it to doOK with r.client().
|
||||||
func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
||||||
req, err := r.newRequest(ctx, method, path, body)
|
req, err := r.newRequest(ctx, method, path, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return sendRequest(r.client(), req)
|
return doOK(r.client(), req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeAuthToken creates an Ollama auth token for the given private key.
|
// makeAuthToken creates an Ollama auth token for the given private key.
|
||||||
@@ -925,7 +859,7 @@ var supportedSchemes = []string{
|
|||||||
|
|
||||||
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
|
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
|
||||||
|
|
||||||
// parseNameExtended parses and validates an extended name, returning the scheme, name,
|
// parseName parses and validates an extended name, returning the scheme, name,
|
||||||
// and digest.
|
// and digest.
|
||||||
//
|
//
|
||||||
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
||||||
@@ -936,8 +870,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
|
|||||||
//
|
//
|
||||||
// If the name is not, once merged with the mask, fully qualified,
|
// If the name is not, once merged with the mask, fully qualified,
|
||||||
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||||
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
||||||
scheme, name, digest := splitExtended(s)
|
scheme, name, digest := names.Split(s)
|
||||||
scheme = cmp.Or(scheme, "https")
|
scheme = cmp.Or(scheme, "https")
|
||||||
if !slices.Contains(supportedSchemes, scheme) {
|
if !slices.Contains(supportedSchemes, scheme) {
|
||||||
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
||||||
@@ -960,58 +894,13 @@ func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := r.parseName(name)
|
maskName := defaultMask
|
||||||
if err != nil {
|
if mask != "" {
|
||||||
return "", names.Name{}, blob.Digest{}, err
|
maskName = names.Parse(mask)
|
||||||
|
}
|
||||||
|
n := names.Merge(names.Parse(name), maskName)
|
||||||
|
if !n.IsFullyQualified() {
|
||||||
|
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
|
||||||
}
|
}
|
||||||
return scheme, n, d, nil
|
return scheme, n, d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitExtended splits an extended name string into its scheme, name, and digest
|
|
||||||
// parts.
|
|
||||||
//
|
|
||||||
// Examples:
|
|
||||||
//
|
|
||||||
// http://ollama.com/bmizerany/smol:latest@digest
|
|
||||||
// https://ollama.com/bmizerany/smol:latest
|
|
||||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
|
||||||
// model@digest
|
|
||||||
// @digest
|
|
||||||
func splitExtended(s string) (scheme, name, digest string) {
|
|
||||||
i := strings.Index(s, "://")
|
|
||||||
if i >= 0 {
|
|
||||||
scheme = s[:i]
|
|
||||||
s = s[i+3:]
|
|
||||||
}
|
|
||||||
i = strings.LastIndex(s, "@")
|
|
||||||
if i >= 0 {
|
|
||||||
digest = s[i+1:]
|
|
||||||
s = s[:i]
|
|
||||||
}
|
|
||||||
return scheme, s, digest
|
|
||||||
}
|
|
||||||
|
|
||||||
type writerPool struct {
|
|
||||||
size int64 // set by the caller
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
ws []*bufio.Writer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *writerPool) get() *bufio.Writer {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
if len(p.ws) == 0 {
|
|
||||||
return bufio.NewWriterSize(nil, int(p.size))
|
|
||||||
}
|
|
||||||
w := p.ws[len(p.ws)-1]
|
|
||||||
p.ws = p.ws[:len(p.ws)-1]
|
|
||||||
return w
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *writerPool) put(w *bufio.Writer) {
|
|
||||||
p.mu.Lock()
|
|
||||||
defer p.mu.Unlock()
|
|
||||||
w.Reset(nil)
|
|
||||||
p.ws = append(p.ws, w)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package ollama
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"cmp"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -73,7 +72,6 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||||||
// To simulate a network error, pass a handler that returns a 499 status code.
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
||||||
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
c, err := blob.Open(t.TempDir())
|
c, err := blob.Open(t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -87,14 +85,13 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := &Registry{
|
r := &Registry{
|
||||||
Cache: c,
|
|
||||||
HTTPClient: &http.Client{
|
HTTPClient: &http.Client{
|
||||||
Transport: recordRoundTripper(h),
|
Transport: recordRoundTripper(h),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
link := func(name string, manifest string) {
|
link := func(name string, manifest string) {
|
||||||
n, err := r.parseName(name)
|
_, n, _, err := parseName(name, r.Mask)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -154,55 +151,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPushZero(t *testing.T) {
|
func TestPushZero(t *testing.T) {
|
||||||
rc, _ := newClient(t, okHandler)
|
rc, c := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), "empty", nil)
|
err := rc.Push(t.Context(), c, "empty", nil)
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushSingle(t *testing.T) {
|
func TestPushSingle(t *testing.T) {
|
||||||
rc, _ := newClient(t, okHandler)
|
rc, c := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), "single", nil)
|
err := rc.Push(t.Context(), c, "single", nil)
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushMultiple(t *testing.T) {
|
func TestPushMultiple(t *testing.T) {
|
||||||
rc, _ := newClient(t, okHandler)
|
rc, c := newClient(t, okHandler)
|
||||||
err := rc.Push(t.Context(), "multiple", nil)
|
err := rc.Push(t.Context(), c, "multiple", nil)
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushNotFound(t *testing.T) {
|
func TestPushNotFound(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Errorf("unexpected request: %v", r)
|
t.Errorf("unexpected request: %v", r)
|
||||||
})
|
})
|
||||||
err := rc.Push(t.Context(), "notfound", nil)
|
err := rc.Push(t.Context(), c, "notfound", nil)
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
if !errors.Is(err, fs.ErrNotExist) {
|
||||||
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushNullLayer(t *testing.T) {
|
func TestPushNullLayer(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
rc, c := newClient(t, nil)
|
||||||
err := rc.Push(t.Context(), "null", nil)
|
err := rc.Push(t.Context(), c, "null", nil)
|
||||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushSizeMismatch(t *testing.T) {
|
func TestPushSizeMismatch(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
rc, c := newClient(t, nil)
|
||||||
ctx, _ := withTraceUnexpected(t.Context())
|
ctx, _ := withTraceUnexpected(t.Context())
|
||||||
got := rc.Push(ctx, "sizemismatch", nil)
|
got := rc.Push(ctx, c, "sizemismatch", nil)
|
||||||
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
||||||
t.Errorf("err = %v; want size mismatch", got)
|
t.Errorf("err = %v; want size mismatch", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushInvalid(t *testing.T) {
|
func TestPushInvalid(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
rc, c := newClient(t, nil)
|
||||||
err := rc.Push(t.Context(), "invalid", nil)
|
err := rc.Push(t.Context(), c, "invalid", nil)
|
||||||
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
@@ -210,7 +207,7 @@ func TestPushInvalid(t *testing.T) {
|
|||||||
|
|
||||||
func TestPushExistsAtRemote(t *testing.T) {
|
func TestPushExistsAtRemote(t *testing.T) {
|
||||||
var pushed bool
|
var pushed bool
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/uploads/") {
|
if strings.Contains(r.URL.Path, "/uploads/") {
|
||||||
if !pushed {
|
if !pushed {
|
||||||
// First push. Return an uploadURL.
|
// First push. Return an uploadURL.
|
||||||
@@ -238,35 +235,35 @@ func TestPushExistsAtRemote(t *testing.T) {
|
|||||||
|
|
||||||
check := testutil.Checker(t)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
err := rc.Push(ctx, "single", nil)
|
err := rc.Push(ctx, c, "single", nil)
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
if !errors.Is(errors.Join(errs...), nil) {
|
if !errors.Is(errors.Join(errs...), nil) {
|
||||||
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
||||||
}
|
}
|
||||||
|
|
||||||
err = rc.Push(ctx, "single", nil)
|
err = rc.Push(ctx, c, "single", nil)
|
||||||
check(err)
|
check(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushRemoteError(t *testing.T) {
|
func TestPushRemoteError(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
w.WriteHeader(500)
|
w.WriteHeader(500)
|
||||||
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), "single", nil)
|
got := rc.Push(t.Context(), c, "single", nil)
|
||||||
checkErrCode(t, got, 500, "blob_error")
|
checkErrCode(t, got, 500, "blob_error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushLocationError(t *testing.T) {
|
func TestPushLocationError(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Location", ":///x")
|
w.Header().Set("Location", ":///x")
|
||||||
w.WriteHeader(http.StatusAccepted)
|
w.WriteHeader(http.StatusAccepted)
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), "single", nil)
|
got := rc.Push(t.Context(), c, "single", nil)
|
||||||
wantContains := "invalid upload URL"
|
wantContains := "invalid upload URL"
|
||||||
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
||||||
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
||||||
@@ -274,14 +271,14 @@ func TestPushLocationError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestPushUploadRoundtripError(t *testing.T) {
|
func TestPushUploadRoundtripError(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Host == "blob.store" {
|
if r.Host == "blob.store" {
|
||||||
w.WriteHeader(499) // force RoundTrip error on upload
|
w.WriteHeader(499) // force RoundTrip error on upload
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Location", "http://blob.store/blobs/123")
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
||||||
})
|
})
|
||||||
got := rc.Push(t.Context(), "single", nil)
|
got := rc.Push(t.Context(), c, "single", nil)
|
||||||
if !errors.Is(got, errRoundTrip) {
|
if !errors.Is(got, errRoundTrip) {
|
||||||
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
||||||
}
|
}
|
||||||
@@ -297,20 +294,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
|
|||||||
os.Remove(c.GetFile(l.Digest))
|
os.Remove(c.GetFile(l.Digest))
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
got := rc.Push(ctx, "single", nil)
|
got := rc.Push(ctx, c, "single", nil)
|
||||||
if !errors.Is(got, fs.ErrNotExist) {
|
if !errors.Is(got, fs.ErrNotExist) {
|
||||||
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPushCommitRoundtripError(t *testing.T) {
|
func TestPushCommitRoundtripError(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
panic("unexpected")
|
panic("unexpected")
|
||||||
}
|
}
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
w.WriteHeader(499) // force RoundTrip error
|
||||||
})
|
})
|
||||||
err := rc.Push(t.Context(), "zero", nil)
|
err := rc.Push(t.Context(), c, "zero", nil)
|
||||||
if !errors.Is(err, errRoundTrip) {
|
if !errors.Is(err, errRoundTrip) {
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||||
}
|
}
|
||||||
@@ -324,8 +321,8 @@ func checkNotExist(t *testing.T, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullInvalidName(t *testing.T) {
|
func TestRegistryPullInvalidName(t *testing.T) {
|
||||||
rc, _ := newClient(t, nil)
|
rc, c := newClient(t, nil)
|
||||||
err := rc.Pull(t.Context(), "://")
|
err := rc.Pull(t.Context(), c, "://")
|
||||||
if !errors.Is(err, ErrNameInvalid) {
|
if !errors.Is(err, ErrNameInvalid) {
|
||||||
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
||||||
}
|
}
|
||||||
@@ -340,10 +337,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, resp := range cases {
|
for _, resp := range cases {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
io.WriteString(w, resp)
|
io.WriteString(w, resp)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), "x")
|
err := rc.Pull(t.Context(), c, "x")
|
||||||
if !errors.Is(err, ErrManifestInvalid) {
|
if !errors.Is(err, ErrManifestInvalid) {
|
||||||
t.Errorf("err = %v; want invalid manifest", err)
|
t.Errorf("err = %v; want invalid manifest", err)
|
||||||
}
|
}
|
||||||
@@ -366,18 +363,18 @@ func TestRegistryPullNotCached(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Confirm that the layer does not exist locally
|
// Confirm that the layer does not exist locally
|
||||||
_, err := rc.ResolveLocal("model")
|
_, err := rc.ResolveLocal(c, "model")
|
||||||
checkNotExist(t, err)
|
checkNotExist(t, err)
|
||||||
|
|
||||||
_, err = c.Get(d)
|
_, err = c.Get(d)
|
||||||
checkNotExist(t, err)
|
checkNotExist(t, err)
|
||||||
|
|
||||||
err = rc.Pull(t.Context(), "model")
|
err = rc.Pull(t.Context(), c, "model")
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
mw, err := rc.Resolve(t.Context(), "model")
|
mw, err := rc.Resolve(t.Context(), "model")
|
||||||
check(err)
|
check(err)
|
||||||
mg, err := rc.ResolveLocal("model")
|
mg, err := rc.ResolveLocal(c, "model")
|
||||||
check(err)
|
check(err)
|
||||||
if !reflect.DeepEqual(mw, mg) {
|
if !reflect.DeepEqual(mw, mg) {
|
||||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||||
@@ -402,7 +399,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
|||||||
|
|
||||||
func TestRegistryPullCached(t *testing.T) {
|
func TestRegistryPullCached(t *testing.T) {
|
||||||
cached := blob.DigestFromBytes("exists")
|
cached := blob.DigestFromBytes("exists")
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/blobs/") {
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
||||||
w.WriteHeader(499) // should not be called
|
w.WriteHeader(499) // should not be called
|
||||||
return
|
return
|
||||||
@@ -425,7 +422,7 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := rc.Pull(ctx, "single")
|
err := rc.Pull(ctx, c, "single")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{6}
|
want := []int64{6}
|
||||||
@@ -438,30 +435,30 @@ func TestRegistryPullCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullManifestNotFound(t *testing.T) {
|
func TestRegistryPullManifestNotFound(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), "notfound")
|
err := rc.Pull(t.Context(), c, "notfound")
|
||||||
checkErrCode(t, err, 404, "")
|
checkErrCode(t, err, 404, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
func TestRegistryPullResolveRemoteError(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), "single")
|
err := rc.Pull(t.Context(), c, "single")
|
||||||
checkErrCode(t, err, 500, "an_error")
|
checkErrCode(t, err, 500, "an_error")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
func TestRegistryPullResolveRoundtripError(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
if strings.Contains(r.URL.Path, "/manifests/") {
|
if strings.Contains(r.URL.Path, "/manifests/") {
|
||||||
w.WriteHeader(499) // force RoundTrip error
|
w.WriteHeader(499) // force RoundTrip error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := rc.Pull(t.Context(), "single")
|
err := rc.Pull(t.Context(), c, "single")
|
||||||
if !errors.Is(err, errRoundTrip) {
|
if !errors.Is(err, errRoundTrip) {
|
||||||
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
||||||
}
|
}
|
||||||
@@ -514,7 +511,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
|
|
||||||
// Check that we pull all layers that we can.
|
// Check that we pull all layers that we can.
|
||||||
|
|
||||||
err := rc.Pull(ctx, "mixed")
|
err := rc.Pull(ctx, c, "mixed")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -532,7 +529,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistryPullChunking(t *testing.T) {
|
func TestRegistryPullChunking(t *testing.T) {
|
||||||
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
|
||||||
if r.URL.Host != "blob.store" {
|
if r.URL.Host != "blob.store" {
|
||||||
// The production registry redirects to the blob store.
|
// The production registry redirects to the blob store.
|
||||||
@@ -570,7 +567,7 @@ func TestRegistryPullChunking(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
err := rc.Pull(ctx, "remote")
|
err := rc.Pull(ctx, c, "remote")
|
||||||
testutil.Check(t, err)
|
testutil.Check(t, err)
|
||||||
|
|
||||||
want := []int64{0, 3, 6}
|
want := []int64{0, 3, 6}
|
||||||
@@ -608,7 +605,7 @@ func TestInsecureSkipVerify(t *testing.T) {
|
|||||||
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
||||||
_, err := rc.Resolve(t.Context(), url)
|
_, err := rc.Resolve(t.Context(), url)
|
||||||
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
|
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
|
||||||
t.Errorf("err = %v; want cert verification failure", err)
|
t.Errorf("err = %v; want cert verifiction failure", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
|
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
|
||||||
@@ -712,16 +709,25 @@ func TestErrorUnmarshal(t *testing.T) {
|
|||||||
//
|
//
|
||||||
// It is only for testing error messages, not that all invalids and valids are
|
// It is only for testing error messages, not that all invalids and valids are
|
||||||
// covered. Those are in other tests for names.Name and blob.Digest.
|
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||||
func TestParseNameExtendedErrors(t *testing.T) {
|
func TestParseNameErrors(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
name string
|
name string
|
||||||
err error
|
err error
|
||||||
want string
|
want string
|
||||||
}{}
|
}{
|
||||||
|
{"x", nil, ""},
|
||||||
|
{"x@", nil, ""},
|
||||||
|
|
||||||
|
{"", ErrNameInvalid, `invalid or missing name: ""`},
|
||||||
|
{"://", ErrNameInvalid, `invalid or missing name: "://"`},
|
||||||
|
{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
|
||||||
|
|
||||||
|
{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||||
|
{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
|
||||||
|
}
|
||||||
|
|
||||||
var r Registry
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
_, _, _, err := r.parseNameExtended(tt.name)
|
_, _, _, err := parseName(tt.name, DefaultMask)
|
||||||
if !errors.Is(err, tt.err) {
|
if !errors.Is(err, tt.err) {
|
||||||
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||||
}
|
}
|
||||||
@@ -730,89 +736,3 @@ func TestParseNameExtendedErrors(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseNameExtended(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
in string
|
|
||||||
scheme string
|
|
||||||
name string
|
|
||||||
digest string
|
|
||||||
err string
|
|
||||||
}{
|
|
||||||
{in: "http://m", scheme: "http", name: "m"},
|
|
||||||
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
|
||||||
{in: "http+insecure://m", err: "unsupported scheme"},
|
|
||||||
|
|
||||||
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
|
||||||
|
|
||||||
{in: "", err: "invalid or missing name"},
|
|
||||||
{in: "m", scheme: "https", name: "m"},
|
|
||||||
{in: "://", err: "invalid or missing name"},
|
|
||||||
{in: "@sha256:deadbeef", err: "invalid digest"},
|
|
||||||
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
|
||||||
}
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.in, func(t *testing.T) {
|
|
||||||
var r Registry
|
|
||||||
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
|
||||||
if err != nil {
|
|
||||||
if tt.err == "" {
|
|
||||||
t.Errorf("err = %v; want nil", err)
|
|
||||||
} else if !strings.Contains(err.Error(), tt.err) {
|
|
||||||
t.Errorf("err = %v; want %q", err, tt.err)
|
|
||||||
}
|
|
||||||
} else if tt.err != "" {
|
|
||||||
t.Errorf("err = nil; want %q", tt.err)
|
|
||||||
}
|
|
||||||
if err == nil && !n.IsFullyQualified() {
|
|
||||||
t.Errorf("name = %q; want fully qualified", n)
|
|
||||||
}
|
|
||||||
|
|
||||||
if scheme != tt.scheme {
|
|
||||||
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
|
||||||
}
|
|
||||||
|
|
||||||
// smoke-test name is superset of tt.name
|
|
||||||
if !strings.Contains(n.String(), tt.name) {
|
|
||||||
t.Errorf("name = %q; want %q", n, tt.name)
|
|
||||||
}
|
|
||||||
|
|
||||||
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
|
||||||
if digest.String() != tt.digest {
|
|
||||||
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnlink(t *testing.T) {
|
|
||||||
t.Run("found by name", func(t *testing.T) {
|
|
||||||
rc, _ := newClient(t, nil)
|
|
||||||
|
|
||||||
// confirm linked
|
|
||||||
_, err := rc.ResolveLocal("single")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// unlink
|
|
||||||
_, err = rc.Unlink("single")
|
|
||||||
testutil.Check(t, err)
|
|
||||||
|
|
||||||
// confirm unlinked
|
|
||||||
_, err = rc.ResolveLocal("single")
|
|
||||||
if !errors.Is(err, fs.ErrNotExist) {
|
|
||||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("not found by name", func(t *testing.T) {
|
|
||||||
rc, _ := newClient(t, nil)
|
|
||||||
ok, err := rc.Unlink("manifestNotFound")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
t.Error("expected not found")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,20 +6,13 @@ import (
|
|||||||
|
|
||||||
// Trace is a set of functions that are called to report progress during blob
|
// Trace is a set of functions that are called to report progress during blob
|
||||||
// downloads and uploads.
|
// downloads and uploads.
|
||||||
//
|
|
||||||
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
|
|
||||||
// and [Registry.Pull].
|
|
||||||
type Trace struct {
|
type Trace struct {
|
||||||
// Update is called during [Registry.Push] and [Registry.Pull] to
|
// Update is called during [Registry.Push] and [Registry.Pull] to
|
||||||
// report the progress of blob uploads and downloads.
|
// report the progress of blob uploads and downloads.
|
||||||
//
|
//
|
||||||
// The n argument is the number of bytes transferred so far, and err is
|
// It is called once at the beginning of the download with a zero n and
|
||||||
// any error that has occurred. If n == 0, and err is nil, the download
|
// then once per read operation with the number of bytes read so far,
|
||||||
// or upload has just started. If err is [ErrCached], the download or
|
// and an error if any.
|
||||||
// upload has been skipped because the blob is already present in the
|
|
||||||
// local cache or remote registry, respectively. Otherwise, if err is
|
|
||||||
// non-nil, the download or upload has failed. When l.Size == n, and
|
|
||||||
// err is nil, the download or upload has completed.
|
|
||||||
//
|
//
|
||||||
// A function assigned must be safe for concurrent use. The function is
|
// A function assigned must be safe for concurrent use. The function is
|
||||||
// called synchronously and so should not block or take long to run.
|
// called synchronously and so should not block or take long to run.
|
||||||
|
|||||||
@@ -63,28 +63,25 @@ func main() {
|
|||||||
}
|
}
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
c, err := ollama.DefaultCache()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := ollama.DefaultRegistry()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
err := func() error {
|
err = func() error {
|
||||||
switch cmd := flag.Arg(0); cmd {
|
switch cmd := flag.Arg(0); cmd {
|
||||||
case "pull":
|
case "pull":
|
||||||
rc, err := ollama.DefaultRegistry()
|
return cmdPull(ctx, rc, c)
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmdPull(ctx, rc)
|
|
||||||
case "push":
|
case "push":
|
||||||
rc, err := ollama.DefaultRegistry()
|
return cmdPush(ctx, rc, c)
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
return cmdPush(ctx, rc)
|
|
||||||
case "import":
|
case "import":
|
||||||
c, err := ollama.DefaultCache()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
return cmdImport(ctx, c)
|
return cmdImport(ctx, c)
|
||||||
default:
|
default:
|
||||||
if cmd == "" {
|
if cmd == "" {
|
||||||
@@ -102,7 +99,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
||||||
model := flag.Arg(1)
|
model := flag.Arg(1)
|
||||||
if model == "" {
|
if model == "" {
|
||||||
flag.Usage()
|
flag.Usage()
|
||||||
@@ -148,7 +145,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
|||||||
|
|
||||||
errc := make(chan error)
|
errc := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
errc <- rc.Pull(ctx, model)
|
errc <- rc.Pull(ctx, c, model)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
t := time.NewTicker(time.Second)
|
t := time.NewTicker(time.Second)
|
||||||
@@ -164,7 +161,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
|
||||||
args := flag.Args()[1:]
|
args := flag.Args()[1:]
|
||||||
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
flag := flag.NewFlagSet("push", flag.ExitOnError)
|
||||||
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
|
||||||
@@ -180,7 +177,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
from := cmp.Or(*flagFrom, model)
|
from := cmp.Or(*flagFrom, model)
|
||||||
m, err := rc.ResolveLocal(from)
|
m, err := rc.ResolveLocal(c, from)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -206,7 +203,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry) error {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
return rc.Push(ctx, model, &ollama.PushParams{
|
return rc.Push(ctx, c, model, &ollama.PushParams{
|
||||||
From: from,
|
From: from,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build goexperiment.synctest
|
|
||||||
|
|
||||||
package backoff
|
package backoff
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build goexperiment.synctest
|
|
||||||
|
|
||||||
package syncs
|
package syncs
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -7,12 +7,9 @@ import (
|
|||||||
"cmp"
|
"cmp"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
@@ -30,15 +27,12 @@ import (
|
|||||||
// directly to the blob disk cache.
|
// directly to the blob disk cache.
|
||||||
type Local struct {
|
type Local struct {
|
||||||
Client *ollama.Registry // required
|
Client *ollama.Registry // required
|
||||||
|
Cache *blob.DiskCache // required
|
||||||
Logger *slog.Logger // required
|
Logger *slog.Logger // required
|
||||||
|
|
||||||
// Fallback, if set, is used to handle requests that are not handled by
|
// Fallback, if set, is used to handle requests that are not handled by
|
||||||
// this handler.
|
// this handler.
|
||||||
Fallback http.Handler
|
Fallback http.Handler
|
||||||
|
|
||||||
// Prune, if set, is called to prune the local disk cache after a model
|
|
||||||
// is deleted.
|
|
||||||
Prune func() error // optional
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// serverError is like ollama.Error, but with a Status field for the HTTP
|
// serverError is like ollama.Error, but with a Status field for the HTTP
|
||||||
@@ -113,8 +107,6 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
|||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/api/delete":
|
case "/api/delete":
|
||||||
return false, s.handleDelete(rec, r)
|
return false, s.handleDelete(rec, r)
|
||||||
case "/api/pull":
|
|
||||||
return false, s.handlePull(rec, r)
|
|
||||||
default:
|
default:
|
||||||
if s.Fallback != nil {
|
if s.Fallback != nil {
|
||||||
s.Fallback.ServeHTTP(rec, r)
|
s.Fallback.ServeHTTP(rec, r)
|
||||||
@@ -207,107 +199,13 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ok, err := s.Client.Unlink(p.model())
|
ok, err := s.Client.Unlink(s.Cache, p.model())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return &serverError{404, "not_found", "model not found"}
|
return &serverError{404, "not_found", "model not found"}
|
||||||
}
|
}
|
||||||
if s.Prune == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return s.Prune()
|
|
||||||
}
|
|
||||||
|
|
||||||
type progressUpdateJSON struct {
|
|
||||||
Status string `json:"status"`
|
|
||||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
|
||||||
Total int64 `json:"total,omitempty,omitzero"`
|
|
||||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if r.Method != "POST" {
|
|
||||||
return errMethodNotAllowed
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := decodeUserJSON[*params](r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
maybeFlush := func() {
|
|
||||||
fl, _ := w.(http.Flusher)
|
|
||||||
if fl != nil {
|
|
||||||
fl.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer maybeFlush()
|
|
||||||
|
|
||||||
var mu sync.Mutex
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
|
||||||
|
|
||||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
|
||||||
Update: func(l *ollama.Layer, n int64, err error) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
// TODO(bmizerany): coalesce these updates; writing per
|
|
||||||
// update is expensive
|
|
||||||
enc.Encode(progressUpdateJSON{
|
|
||||||
Digest: l.Digest,
|
|
||||||
Status: "pulling",
|
|
||||||
Total: l.Size,
|
|
||||||
Completed: n,
|
|
||||||
})
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
// TODO(bmizerany): continue to support non-streaming responses
|
|
||||||
done <- s.Client.Pull(ctx, p.model())
|
|
||||||
}()
|
|
||||||
|
|
||||||
func() {
|
|
||||||
t := time.NewTicker(100 * time.Millisecond)
|
|
||||||
defer t.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-t.C:
|
|
||||||
mu.Lock()
|
|
||||||
maybeFlush()
|
|
||||||
mu.Unlock()
|
|
||||||
case err := <-done:
|
|
||||||
if err != nil {
|
|
||||||
var status string
|
|
||||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
|
||||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
|
||||||
enc.Encode(progressUpdateJSON{Status: status})
|
|
||||||
} else {
|
|
||||||
status = fmt.Sprintf("error: %v", err)
|
|
||||||
enc.Encode(progressUpdateJSON{Status: status})
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// These final updates are not strictly necessary, because they have
|
|
||||||
// already happened at this point. Our pull handler code used to do
|
|
||||||
// these steps after, not during, the pull, and they were slow, so we
|
|
||||||
// wanted to provide feedback to users what was happening. For now, we
|
|
||||||
// keep them to not jar users who are used to seeing them. We can phase
|
|
||||||
// them out with a new and nicer UX later. One without progress bars
|
|
||||||
// and digests that no one cares about.
|
|
||||||
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
|
|
||||||
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
|
||||||
enc.Encode(progressUpdateJSON{Status: "success"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +1,17 @@
|
|||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/fs"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/testutil"
|
"github.com/ollama/ollama/server/internal/testutil"
|
||||||
"golang.org/x/tools/txtar"
|
|
||||||
|
|
||||||
_ "embed"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type panicTransport struct{}
|
type panicTransport struct{}
|
||||||
@@ -40,7 +30,7 @@ type bytesResetter interface {
|
|||||||
Reset()
|
Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
|
func newTestServer(t *testing.T) *Local {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
err := os.CopyFS(dir, os.DirFS("testdata/models"))
|
err := os.CopyFS(dir, os.DirFS("testdata/models"))
|
||||||
@@ -51,26 +41,11 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := panicOnRoundTrip
|
|
||||||
if upstreamRegistry != nil {
|
|
||||||
s := httptest.NewTLSServer(upstreamRegistry)
|
|
||||||
t.Cleanup(s.Close)
|
|
||||||
tr := s.Client().Transport.(*http.Transport).Clone()
|
|
||||||
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
var d net.Dialer
|
|
||||||
return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
|
|
||||||
}
|
|
||||||
client = &http.Client{Transport: tr}
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := &ollama.Registry{
|
rc := &ollama.Registry{
|
||||||
Cache: c,
|
HTTPClient: panicOnRoundTrip,
|
||||||
HTTPClient: client,
|
|
||||||
Mask: "example.com/library/_:latest",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
l := &Local{
|
l := &Local{
|
||||||
|
Cache: c,
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: testutil.Slogger(t),
|
Logger: testutil.Slogger(t),
|
||||||
}
|
}
|
||||||
@@ -110,9 +85,9 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
|
|||||||
func TestServerDelete(t *testing.T) {
|
func TestServerDelete(t *testing.T) {
|
||||||
check := testutil.Checker(t)
|
check := testutil.Checker(t)
|
||||||
|
|
||||||
s := newTestServer(t, nil)
|
s := newTestServer(t)
|
||||||
|
|
||||||
_, err := s.Client.ResolveLocal("smol")
|
_, err := s.Client.ResolveLocal(s.Cache, "smol")
|
||||||
check(err)
|
check(err)
|
||||||
|
|
||||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||||
@@ -120,7 +95,7 @@ func TestServerDelete(t *testing.T) {
|
|||||||
t.Fatalf("Code = %d; want 200", got.Code)
|
t.Fatalf("Code = %d; want 200", got.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = s.Client.ResolveLocal("smol")
|
_, err = s.Client.ResolveLocal(s.Cache, "smol")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected smol to have been deleted")
|
t.Fatal("expected smol to have been deleted")
|
||||||
}
|
}
|
||||||
@@ -152,105 +127,8 @@ func TestServerDelete(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//go:embed testdata/registry.txt
|
|
||||||
var registryTXT []byte
|
|
||||||
|
|
||||||
var registryFS = sync.OnceValue(func() fs.FS {
|
|
||||||
// Txtar gets hung up on \r\n line endings, so we need to convert them
|
|
||||||
// to \n when parsing the txtar on Windows.
|
|
||||||
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
|
||||||
a := txtar.Parse(data)
|
|
||||||
fmt.Printf("%q\n", a.Comment)
|
|
||||||
fsys, err := txtar.FS(a)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return fsys
|
|
||||||
})
|
|
||||||
|
|
||||||
func TestServerPull(t *testing.T) {
|
|
||||||
modelsHandler := http.FileServerFS(registryFS())
|
|
||||||
s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.URL.Path {
|
|
||||||
case "/v2/library/BOOM/manifests/latest":
|
|
||||||
w.WriteHeader(999)
|
|
||||||
io.WriteString(w, `{"error": "boom"}`)
|
|
||||||
case "/v2/library/unknown/manifests/latest":
|
|
||||||
w.WriteHeader(404)
|
|
||||||
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
|
||||||
default:
|
|
||||||
t.Logf("serving file: %s", r.URL.Path)
|
|
||||||
modelsHandler.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if got.Code != 200 {
|
|
||||||
t.Fatalf("Code = %d; want 200", got.Code)
|
|
||||||
}
|
|
||||||
gotlines := got.Body.String()
|
|
||||||
t.Logf("got:\n%s", gotlines)
|
|
||||||
for want := range strings.Lines(wantlines) {
|
|
||||||
want = strings.TrimSpace(want)
|
|
||||||
want, unwanted := strings.CutPrefix(want, "!")
|
|
||||||
want = strings.TrimSpace(want)
|
|
||||||
if !unwanted && !strings.Contains(gotlines, want) {
|
|
||||||
t.Fatalf("! missing %q in body", want)
|
|
||||||
}
|
|
||||||
if unwanted && strings.Contains(gotlines, want) {
|
|
||||||
t.Fatalf("! unexpected %q in body", want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
|
||||||
checkResponse(got, `
|
|
||||||
{"status":"pulling manifest"}
|
|
||||||
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
|
||||||
`)
|
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
|
||||||
checkResponse(got, `
|
|
||||||
{"status":"pulling manifest"}
|
|
||||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
|
||||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
|
||||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
|
||||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
|
||||||
{"status":"verifying layers"}
|
|
||||||
{"status":"writing manifest"}
|
|
||||||
{"status":"success"}
|
|
||||||
`)
|
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
|
||||||
checkResponse(got, `
|
|
||||||
{"status":"pulling manifest"}
|
|
||||||
{"status":"error: model \"unknown\" not found"}
|
|
||||||
`)
|
|
||||||
|
|
||||||
got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
|
|
||||||
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `!`)
|
|
||||||
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", ``)
|
|
||||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
|
||||||
|
|
||||||
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
|
||||||
checkResponse(got, `
|
|
||||||
{"status":"pulling manifest"}
|
|
||||||
{"status":"error: invalid or missing name: \"\""}
|
|
||||||
|
|
||||||
!verifying
|
|
||||||
!writing
|
|
||||||
!success
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServerUnknownPath(t *testing.T) {
|
func TestServerUnknownPath(t *testing.T) {
|
||||||
s := newTestServer(t, nil)
|
s := newTestServer(t)
|
||||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||||
}
|
}
|
||||||
|
|||||||
22
server/internal/registry/testdata/registry.txt
vendored
22
server/internal/registry/testdata/registry.txt
vendored
@@ -1,22 +0,0 @@
|
|||||||
-- v2/library/smol/manifests/latest --
|
|
||||||
{
|
|
||||||
"schemaVersion": 2,
|
|
||||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
|
||||||
"config": {
|
|
||||||
"mediaType": "application/vnd.docker.container.image.v1+json",
|
|
||||||
"digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
|
|
||||||
"size": 3
|
|
||||||
},
|
|
||||||
"layers": [
|
|
||||||
{
|
|
||||||
"mediaType": "application/vnd.ollama.image.model",
|
|
||||||
"digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
|
|
||||||
"size": 5
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
|
|
||||||
GGUF
|
|
||||||
-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
|
|
||||||
{}
|
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/model/models/mllama"
|
"github.com/ollama/ollama/model/models/mllama"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@@ -26,7 +27,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
var system []api.Message
|
var system []api.Message
|
||||||
|
|
||||||
isMllama := checkMllamaModelFamily(m)
|
isMllama := checkMllamaModelFamily(m)
|
||||||
isGemma3 := checkGemma3ModelFamily(m)
|
|
||||||
|
|
||||||
var imageNumTokens int
|
var imageNumTokens int
|
||||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||||
@@ -41,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
n := len(msgs) - 1
|
n := len(msgs) - 1
|
||||||
// in reverse, find all messages that fit into context window
|
// in reverse, find all messages that fit into context window
|
||||||
for i := n; i >= 0; i-- {
|
for i := n; i >= 0; i-- {
|
||||||
if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
|
if isMllama && len(msgs[i].Images) > 1 {
|
||||||
return "", nil, errTooManyImages
|
return "", nil, errTooManyImages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
|||||||
var imgData llm.ImageData
|
var imgData llm.ImageData
|
||||||
|
|
||||||
if isMllama {
|
if isMllama {
|
||||||
if len(m.ProjectorPaths) == 0 {
|
if envconfig.NewEngine() {
|
||||||
imgData = llm.ImageData{
|
imgData = llm.ImageData{
|
||||||
ID: len(images),
|
ID: len(images),
|
||||||
Data: i,
|
Data: i,
|
||||||
@@ -158,12 +158,3 @@ func checkMllamaModelFamily(m *Model) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkGemma3ModelFamily(m *Model) bool {
|
|
||||||
for _, arch := range m.Config.ModelFamilies {
|
|
||||||
if arch == "gemma3" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/model/models/mllama"
|
"github.com/ollama/ollama/model/models/mllama"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@@ -42,12 +43,6 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
func experimentEnabled(name string) bool {
|
|
||||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
|
||||||
}
|
|
||||||
|
|
||||||
var useClient2 = experimentEnabled("client2")
|
|
||||||
|
|
||||||
var mode string = gin.DebugMode
|
var mode string = gin.DebugMode
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
@@ -211,7 +206,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
|
|
||||||
images := make([]llm.ImageData, len(req.Images))
|
images := make([]llm.ImageData, len(req.Images))
|
||||||
for i := range req.Images {
|
for i := range req.Images {
|
||||||
if isMllama && len(model.ProjectorPaths) > 0 {
|
if isMllama && !envconfig.NewEngine() {
|
||||||
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
|
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
|
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
|
||||||
@@ -435,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
kvData, _, err := getModelData(m.ModelPath, false)
|
kvData, err := getKVData(m.ModelPath, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -483,7 +478,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
slog.Error("embedding generation failed", "error", err)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -544,7 +540,8 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||||||
|
|
||||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -848,23 +845,16 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
fmt.Fprint(&sb, m.String())
|
fmt.Fprint(&sb, m.String())
|
||||||
resp.Modelfile = sb.String()
|
resp.Modelfile = sb.String()
|
||||||
|
|
||||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
kvData, err := getKVData(m.ModelPath, req.Verbose)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(kvData, "general.name")
|
delete(kvData, "general.name")
|
||||||
delete(kvData, "tokenizer.chat_template")
|
delete(kvData, "tokenizer.chat_template")
|
||||||
resp.ModelInfo = kvData
|
resp.ModelInfo = kvData
|
||||||
|
|
||||||
tensorData := make([]api.Tensor, len(tensors.Items()))
|
|
||||||
for cnt, t := range tensors.Items() {
|
|
||||||
tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
|
|
||||||
}
|
|
||||||
resp.Tensors = tensorData
|
|
||||||
|
|
||||||
if len(m.ProjectorPaths) > 0 {
|
if len(m.ProjectorPaths) > 0 {
|
||||||
projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
|
projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -874,17 +864,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
func getKVData(digest string, verbose bool) (ggml.KV, error) {
|
||||||
maxArraySize := 0
|
maxArraySize := 0
|
||||||
if verbose {
|
if verbose {
|
||||||
maxArraySize = -1
|
maxArraySize = -1
|
||||||
}
|
}
|
||||||
data, err := llm.LoadModel(digest, maxArraySize)
|
kvData, err := llm.LoadModel(digest, maxArraySize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ggml.Tensors{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
kv := data.KV()
|
kv := kvData.KV()
|
||||||
|
|
||||||
if !verbose {
|
if !verbose {
|
||||||
for k := range kv {
|
for k := range kv {
|
||||||
@@ -894,7 +884,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv, data.Tensors(), nil
|
return kv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListHandler(c *gin.Context) {
|
func (s *Server) ListHandler(c *gin.Context) {
|
||||||
@@ -1139,7 +1129,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
|
||||||
corsConfig := cors.DefaultConfig()
|
corsConfig := cors.DefaultConfig()
|
||||||
corsConfig.AllowWildcard = true
|
corsConfig.AllowWildcard = true
|
||||||
corsConfig.AllowBrowserExtensions = true
|
corsConfig.AllowBrowserExtensions = true
|
||||||
@@ -1184,7 +1174,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.HEAD("/api/tags", s.ListHandler)
|
r.HEAD("/api/tags", s.ListHandler)
|
||||||
r.GET("/api/tags", s.ListHandler)
|
r.GET("/api/tags", s.ListHandler)
|
||||||
r.POST("/api/show", s.ShowHandler)
|
r.POST("/api/show", s.ShowHandler)
|
||||||
r.DELETE("/api/delete", s.DeleteHandler)
|
|
||||||
|
|
||||||
// Create
|
// Create
|
||||||
r.POST("/api/create", s.CreateHandler)
|
r.POST("/api/create", s.CreateHandler)
|
||||||
@@ -1206,19 +1195,15 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
|
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
|
||||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
|
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
|
||||||
|
|
||||||
if rc != nil {
|
// wrap old with new
|
||||||
// wrap old with new
|
rs := ®istry.Local{
|
||||||
rs := ®istry.Local{
|
Cache: c,
|
||||||
Client: rc,
|
Client: rc,
|
||||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||||
Fallback: r,
|
Fallback: r,
|
||||||
|
|
||||||
Prune: PruneLayers,
|
|
||||||
}
|
|
||||||
return rs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return rs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Serve(ln net.Listener) error {
|
func Serve(ln net.Listener) error {
|
||||||
@@ -1273,20 +1258,19 @@ func Serve(ln net.Listener) error {
|
|||||||
|
|
||||||
s := &Server{addr: ln.Addr()}
|
s := &Server{addr: ln.Addr()}
|
||||||
|
|
||||||
var rc *ollama.Registry
|
c, err := ollama.DefaultCache()
|
||||||
if useClient2 {
|
if err != nil {
|
||||||
var err error
|
return err
|
||||||
rc, err = ollama.DefaultRegistry()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
rc, err := ollama.DefaultRegistry()
|
||||||
h, err := s.GenerateRoutes(rc)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h, err := s.GenerateRoutes(c, rc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
http.Handle("/", h)
|
http.Handle("/", h)
|
||||||
|
|
||||||
ctx, done := context.WithCancel(context.Background())
|
ctx, done := context.WithCancel(context.Background())
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@@ -489,6 +490,11 @@ func TestRoutes(t *testing.T) {
|
|||||||
modelsDir := t.TempDir()
|
modelsDir := t.TempDir()
|
||||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||||
|
|
||||||
|
c, err := blob.Open(modelsDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open models dir: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
rc := &ollama.Registry{
|
rc := &ollama.Registry{
|
||||||
// This is a temporary measure to allow us to move forward,
|
// This is a temporary measure to allow us to move forward,
|
||||||
// surfacing any code contacting ollama.com we do not intended
|
// surfacing any code contacting ollama.com we do not intended
|
||||||
@@ -505,7 +511,7 @@ func TestRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{}
|
s := &Server{}
|
||||||
router, err := s.GenerateRoutes(rc)
|
router, err := s.GenerateRoutes(c, rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to generate routes: %v", err)
|
t.Fatalf("failed to generate routes: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user