Compare commits


1 Commit

Author: Blake Mizerany
SHA1: 9c1204b686
Date: 2025-02-28 16:30:42 -08:00

server/internal/internal/names: validate names

This commit is a step toward making names less ceremonial outside of the
registry client. Clients of the registry package can treat names as opaque
strings, and the registry package will handle parsing, validating, and
normalizing names.

Ideally we end up with the names package tucked away in an internal
package for good. We'll see how things go.

Also, this package name is not permanent. This is another step in the
ongoing process of refactoring the server code, and at some point it
will most likely be renamed/moved.
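The message describes clients treating names as opaque strings while a single package owns parsing, validation, and normalization. A minimal Go sketch of that idea (a hypothetical API, not the actual names package from this commit):

```go
// Hypothetical sketch only: the real server/internal/internal/names API
// differs. Callers pass names around as plain strings; one internal
// package owns parsing, validation, and normalization.
package names

import (
	"fmt"
	"strings"
)

// Normalize validates a "namespace/model:tag" style name and fills in
// assumed defaults for missing parts.
func Normalize(s string) (string, error) {
	name, tag, found := strings.Cut(s, ":")
	if !found {
		tag = "latest" // assumed default tag
	}
	if name == "" || strings.ContainsAny(name, " \t") {
		return "", fmt.Errorf("invalid name: %q", s)
	}
	if !strings.Contains(name, "/") {
		name = "library/" + name // assumed default namespace
	}
	return name + ":" + tag, nil
}
```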
95 changed files with 1671 additions and 5294 deletions


@@ -106,11 +106,9 @@ if(CMAKE_HIP_COMPILER)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
     if (WIN32)
-        target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
+        target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
     endif()
-    target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
     set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
     install(TARGETS ggml-hip
         RUNTIME_DEPENDENCIES


@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
 RUN yum install -y yum-utils \
     && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
     && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
-    && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
+    && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
     && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
 ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -86,11 +86,10 @@ RUN --mount=type=cache,target=/root/.ccache \
     && cmake --install build --component CUDA --strip --parallel 8

 FROM base AS build
-WORKDIR /go/src/github.com/ollama/ollama
-COPY go.mod go.sum .
-RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
+ARG GOVERSION=1.23.4
+RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
 ENV PATH=/usr/local/go/bin:$PATH
-RUN go mod download
+WORKDIR /go/src/github.com/ollama/ollama
 COPY . .
 ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1


@@ -1,5 +1,5 @@
 <div align="center">
-  <a href="https://ollama.com">
+  <a href="https://ollama.com" />
     <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
   </a>
 </div>
@@ -54,11 +54,6 @@ Here are some example models that can be downloaded:
 | Model              | Parameters | Size  | Download                         |
 | ------------------ | ---------- | ----- | -------------------------------- |
-| Gemma 3            | 1B         | 815MB | `ollama run gemma3:1b`           |
-| Gemma 3            | 4B         | 3.3GB | `ollama run gemma3`              |
-| Gemma 3            | 12B        | 8.1GB | `ollama run gemma3:12b`          |
-| Gemma 3            | 27B        | 17GB  | `ollama run gemma3:27b`          |
-| QwQ                | 32B        | 20GB  | `ollama run qwq`                 |
 | DeepSeek-R1        | 7B         | 4.7GB | `ollama run deepseek-r1`         |
 | DeepSeek-R1        | 671B       | 404GB | `ollama run deepseek-r1:671b`    |
 | Llama 3.3          | 70B        | 43GB  | `ollama run llama3.3`            |
@@ -69,7 +64,10 @@ Here are some example models that can be downloaded:
 | Llama 3.1          | 8B         | 4.7GB | `ollama run llama3.1`            |
 | Llama 3.1          | 405B       | 231GB | `ollama run llama3.1:405b`       |
 | Phi 4              | 14B        | 9.1GB | `ollama run phi4`                |
-| Phi 4 Mini         | 3.8B       | 2.5GB | `ollama run phi4-mini`           |
+| Phi 3 Mini         | 3.8B       | 2.3GB | `ollama run phi3`                |
+| Gemma 2            | 2B         | 1.6GB | `ollama run gemma2:2b`           |
+| Gemma 2            | 9B         | 5.5GB | `ollama run gemma2`              |
+| Gemma 2            | 27B        | 16GB  | `ollama run gemma2:27b`          |
 | Mistral            | 7B         | 4.1GB | `ollama run mistral`             |
 | Moondream 2        | 1.4B       | 829MB | `ollama run moondream`           |
 | Neural Chat        | 7B         | 4.1GB | `ollama run neural-chat`         |
@@ -77,7 +75,7 @@ Here are some example models that can be downloaded:
 | Code Llama         | 7B         | 3.8GB | `ollama run codellama`           |
 | Llama 2 Uncensored | 7B         | 3.8GB | `ollama run llama2-uncensored`   |
 | LLaVA              | 7B         | 4.5GB | `ollama run llava`               |
-| Granite-3.2        | 8B         | 4.9GB | `ollama run granite3.2`          |
+| Solar              | 10.7B      | 6.1GB | `ollama run solar`               |

 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -277,7 +275,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Web & Desktop

 - [Open WebUI](https://github.com/open-webui/open-webui)
-- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
 - [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -390,8 +387,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
 - [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
 - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
-- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
-- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

 ### Cloud
@@ -435,7 +430,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Apple Vision Pro

-- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
 - [Enchanted](https://github.com/AugustDev/enchanted)

 ### Database
@@ -513,13 +507,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Mobile

-- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
 - [Enchanted](https://github.com/AugustDev/enchanted)
 - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
 - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
-- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
-- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

 ### Extensions & Plugins
@@ -565,14 +556,12 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
 - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
 - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primarily for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration, etc.)
-- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)

 ### Supported backends

 - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

 ### Observability

-- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration with Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
 - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.


@@ -349,7 +349,6 @@ type ShowResponse struct {
 	Messages      []Message      `json:"messages,omitempty"`
 	ModelInfo     map[string]any `json:"model_info,omitempty"`
 	ProjectorInfo map[string]any `json:"projector_info,omitempty"`
-	Tensors       []Tensor       `json:"tensors,omitempty"`
 	ModifiedAt    time.Time      `json:"modified_at,omitempty"`
 }
@@ -362,9 +361,9 @@ type CopyRequest struct {
 // PullRequest is the request passed to [Client.Pull].
 type PullRequest struct {
 	Model    string `json:"model"`
-	Insecure bool   `json:"insecure,omitempty"` // Deprecated: ignored
-	Username string `json:"username"`           // Deprecated: ignored
-	Password string `json:"password"`           // Deprecated: ignored
+	Insecure bool   `json:"insecure,omitempty"`
+	Username string `json:"username"`
+	Password string `json:"password"`
 	Stream   *bool  `json:"stream,omitempty"`

 	// Deprecated: set the model name with Model instead
@@ -468,13 +467,6 @@ type ModelDetails struct {
 	QuantizationLevel string `json:"quantization_level"`
 }

-// Tensor describes the metadata for a given tensor.
-type Tensor struct {
-	Name  string   `json:"name"`
-	Type  string   `json:"type"`
-	Shape []uint64 `json:"shape"`
-}
-
 func (m *Metrics) Summary() {
 	if m.TotalDuration > 0 {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)


@@ -18,7 +18,6 @@ import (
 	"os/signal"
 	"path/filepath"
 	"runtime"
-	"sort"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -35,6 +34,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/runner"
@@ -256,7 +256,6 @@ func StopHandler(cmd *cobra.Command, args []string) error {
 		if strings.Contains(err.Error(), "not found") {
 			return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
 		}
-		return err
 	}
 	return nil
 }
@@ -339,16 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}

-	if len(info.ProjectorInfo) != 0 {
-		opts.MultiModal = true
-	}
-	for k := range info.ModelInfo {
-		if strings.Contains(k, ".vision.") {
-			opts.MultiModal = true
-			break
-		}
-	}
+	// TODO(jessegross): We should either find another way to know if this is
+	// a vision model or remove the logic. Also consider that other modalities will
+	// need different behavior anyways.
+	opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()

 	opts.ParentModel = info.Details.ParentModel

 	if interactive {
@@ -569,9 +562,8 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	parameters, errParams := cmd.Flags().GetBool("parameters")
 	system, errSystem := cmd.Flags().GetBool("system")
 	template, errTemplate := cmd.Flags().GetBool("template")
-	verbose, errVerbose := cmd.Flags().GetBool("verbose")

-	for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
+	for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
 		if boolErr != nil {
 			return errors.New("error retrieving flags")
 		}
@@ -609,7 +601,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
 	}

-	req := api.ShowRequest{Name: args[0], Verbose: verbose}
+	req := api.ShowRequest{Name: args[0]}
 	resp, err := client.Show(cmd.Context(), &req)
 	if err != nil {
 		return err
@@ -632,10 +624,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}

-	return showInfo(resp, verbose, os.Stdout)
+	return showInfo(resp, os.Stdout)
 }

-func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
+func showInfo(resp *api.ShowResponse, w io.Writer) error {
 	tableRender := func(header string, rows func() [][]string) {
 		fmt.Fprintln(w, "  ", header)
 		table := tablewriter.NewWriter(w)
@@ -692,45 +684,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		})
 	}

-	if resp.ModelInfo != nil && verbose {
-		tableRender("Metadata", func() (rows [][]string) {
-			keys := make([]string, 0, len(resp.ModelInfo))
-			for k := range resp.ModelInfo {
-				keys = append(keys, k)
-			}
-			sort.Strings(keys)
-
-			for _, k := range keys {
-				var v string
-				switch vData := resp.ModelInfo[k].(type) {
-				case string:
-					v = vData
-				case float64:
-					v = fmt.Sprintf("%g", vData)
-				case []any:
-					n := 3
-					if len(vData) < n {
-						n = len(vData)
-					}
-					v = fmt.Sprintf("%v", vData[:n])
-				default:
-					v = fmt.Sprintf("%T", vData)
-				}
-				rows = append(rows, []string{"", k, v})
-			}
-			return
-		})
-	}
-
-	if len(resp.Tensors) > 0 && verbose {
-		tableRender("Tensors", func() (rows [][]string) {
-			for _, t := range resp.Tensors {
-				rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
-			}
-			return
-		})
-	}
-
 	head := func(s string, n int) (rows [][]string) {
 		scanner := bufio.NewScanner(strings.NewReader(s))
 		for scanner.Scan() && (len(rows) < n || n < 0) {
@@ -1237,7 +1190,6 @@ func NewCLI() *cobra.Command {
 	showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
 	showCmd.Flags().Bool("template", false, "Show template of a model")
 	showCmd.Flags().Bool("system", false, "Show system message of a model")
-	showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")

 	runCmd := &cobra.Command{
 		Use:     "run MODEL [PROMPT]",
@@ -1322,6 +1274,7 @@ func NewCLI() *cobra.Command {
 	runnerCmd := &cobra.Command{
 		Use:    "runner",
+		Short:  llama.PrintSystemInfo(),
 		Hidden: true,
 		RunE: func(cmd *cobra.Command, args []string) error {
 			return runner.Execute(os.Args[1:])

@@ -27,7 +27,7 @@ func TestShowInfo(t *testing.T) {
 			ParameterSize:     "7B",
 			QuantizationLevel: "FP16",
 		},
-	}, false, &b); err != nil {
+	}, &b); err != nil {
 		t.Fatal(err)
 	}
@@ -57,7 +57,7 @@ func TestShowInfo(t *testing.T) {
 			ParameterSize:     "7B",
 			QuantizationLevel: "FP16",
 		},
-	}, false, &b); err != nil {
+	}, &b); err != nil {
 		t.Fatal(err)
 	}
@@ -68,56 +68,6 @@ func TestShowInfo(t *testing.T) {
 		embedding length    0
 		quantization        FP16

-`
-		if diff := cmp.Diff(expect, b.String()); diff != "" {
-			t.Errorf("unexpected output (-want +got):\n%s", diff)
-		}
-	})
-
-	t.Run("verbose model", func(t *testing.T) {
-		var b bytes.Buffer
-		if err := showInfo(&api.ShowResponse{
-			Details: api.ModelDetails{
-				Family:            "test",
-				ParameterSize:     "8B",
-				QuantizationLevel: "FP16",
-			},
-			Parameters: `
-			stop up`,
-			ModelInfo: map[string]any{
-				"general.architecture":    "test",
-				"general.parameter_count": float64(8_000_000_000),
-				"test.context_length":     float64(1000),
-				"test.embedding_length":   float64(11434),
-			},
-			Tensors: []api.Tensor{
-				{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
-				{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
-			},
-		}, true, &b); err != nil {
-			t.Fatal(err)
-		}
-
-		expect := `  Model
-			architecture        test
-			parameters          8B
-			context length      1000
-			embedding length    11434
-			quantization        FP16
-
-		  Parameters
-			stop    up
-
-		  Metadata
-			general.architecture       test
-			general.parameter_count    8e+09
-			test.context_length        1000
-			test.embedding_length      11434
-
-		  Tensors
-			blk.0.attn_k.weight    BF16    [42 3117]
-			blk.0.attn_q.weight    FP16    [3117 42]
-
-`
 		if diff := cmp.Diff(expect, b.String()); diff != "" {
 			t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -139,7 +89,7 @@ func TestShowInfo(t *testing.T) {
 			stop    you
 			stop    up
 			temperature    99`,
-	}, false, &b); err != nil {
+	}, &b); err != nil {
 		t.Fatal(err)
 	}
@@ -176,7 +126,7 @@ func TestShowInfo(t *testing.T) {
 			"clip.vision.embedding_length": float64(0),
 			"clip.vision.projection_dim":   float64(0),
 		},
-	}, false, &b); err != nil {
+	}, &b); err != nil {
 		t.Fatal(err)
 	}
@@ -209,7 +159,7 @@ func TestShowInfo(t *testing.T) {
 Ahoy, matey!

 Weigh anchor!
 `,
-	}, false, &b); err != nil {
+	}, &b); err != nil {
 		t.Fatal(err)
 	}
@@ -238,7 +188,7 @@ Weigh anchor!
 			QuantizationLevel: "FP16",
 		},
 		License: license,
-	}, false, &b); err != nil {
+	}, &b); err != nil {
 		t.Fatal(err)
 	}


@@ -195,10 +195,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			opts.Messages = []api.Message{}
 			fmt.Printf("Loading model '%s'\n", opts.Model)
 			if err := loadOrUnloadModel(cmd, &opts); err != nil {
-				if strings.Contains(err.Error(), "not found") {
-					fmt.Printf("error: %v\n", err)
-					continue
-				}
 				return err
 			}
 			continue
@@ -347,7 +343,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			switch args[1] {
 			case "info":
-				_ = showInfo(resp, false, os.Stderr)
+				_ = showInfo(resp, os.Stderr)
 			case "license":
 				if resp.License == "" {
 					fmt.Println("No license was specified for this model.")


@@ -13,13 +13,8 @@ import (
 )

 type ModelParameters struct {
 	Architectures []string `json:"architectures"`
 	VocabSize     uint32   `json:"vocab_size"`
-	TextModel     TextParameters `json:"text_config"`
-}
-
-type TextParameters struct {
-	VocabSize uint32 `json:"vocab_size"`
 }

 type AdapterParameters struct {
@@ -190,8 +185,6 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 		conv = &gemmaModel{}
 	case "Gemma2ForCausalLM":
 		conv = &gemma2Model{}
-	case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
-		conv = &gemma3Model{Architecture: p.Architectures[0]}
 	case "Phi3ForCausalLM":
 		conv = &phi3Model{}
 	case "Qwen2ForCausalLM":
@@ -220,14 +213,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	}

 	vocabSize := int(p.VocabSize)
-	if vocabSize == 0 {
-		tVocabSize := int(p.TextModel.VocabSize)
-		vocabSize = tVocabSize
-	}

 	switch {
-	case vocabSize == 0:
-		slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
 	case vocabSize > len(t.Vocabulary.Tokens):
 		slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
 		for i := range vocabSize - len(t.Vocabulary.Tokens) {


@@ -45,7 +45,7 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
 func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
 	var out []ggml.Tensor
 	for _, t := range ts {
-		if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
+		if strings.HasSuffix(t.Name(), "_norm.weight") {
 			t.SetRepacker(p.addOne)
 		}


@@ -1,142 +0,0 @@
-package convert
-
-import (
-	"cmp"
-
-	"github.com/ollama/ollama/fs/ggml"
-)
-
-type gemma3Model struct {
-	gemmaModel
-	Architecture string
-	TextModel    struct {
-		HeadDim          uint32 `json:"head_dim"`
-		HiddenSize       uint32 `json:"hidden_size"`
-		HiddenLayers     uint32 `json:"num_hidden_layers"`
-		IntermediateSize uint32 `json:"intermediate_size"`
-		SlidingWindow    uint32 `json:"sliding_window"`
-	} `json:"text_config"`
-	VisionModel struct {
-		NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
-		LayerNormEpsilon  float32 `json:"layer_norm_eps"`      // attention.layer_norm_epsilon 1e-05
-		NumHiddenLayers   uint32  `json:"num_hidden_layers"`   // block_count 32
-		HiddenSize        uint32  `json:"hidden_size"`         // embedding_length 1280
-		IntermediateSize  uint32  `json:"intermediate_size"`   // feed_forward_length 5120
-		ImageSize         uint32  `json:"image_size"`          // image_size 560
-		NumChannels       uint32  `json:"num_channels"`        // num_channels 3
-		PatchSize         uint32  `json:"patch_size"`          // patch_size 14
-	} `json:"vision_config"`
-	MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
-	NumAttentionHeads        uint32  `json:"num_attention_heads"`
-	NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
-	RMSNormEPS               float32 `json:"rms_norm_eps"`
-	HeadDim                  uint32  `json:"head_dim"`
-	FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
-	RopeLocalTheta           float32 `json:"rope_local_base_freq"`
-	RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
-	SlidingWindow            uint32  `json:"sliding_window"`
-	MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
-}
-
-const (
-	gemma4BLayerCount  = 34
-	gemma12BLayerCount = 48
-	gemma27BLayerCount = 62
-)
-
-func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
-	kv := p.ModelParameters.KV(t)
-	kv["general.architecture"] = "gemma3"
-
-	numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
-	kv["gemma3.block_count"] = numBlocks
-
-	var (
-		numHeads   uint32
-		numKVHeads uint32
-	)
-
-	switch numBlocks {
-	case gemma4BLayerCount:
-		numHeads = 8
-		numKVHeads = 4
-	case gemma12BLayerCount:
-		numHeads = 16
-		numKVHeads = 8
-	case gemma27BLayerCount:
-		numHeads = 32
-		numKVHeads = 16
-	default:
-		numHeads = p.NumAttentionHeads
-		numKVHeads = p.NumKeyValueHeads
-	}
-
-	kv["gemma3.attention.head_count"] = numHeads
-	kv["gemma3.attention.head_count_kv"] = numKVHeads
-
-	switch p.Architecture {
-	case "Gemma3ForCausalLM":
-		kv["gemma3.context_length"] = p.MaxPositionEmbeddings
-		kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
-		kv["gemma3.attention.key_length"] = p.HeadDim
-		kv["gemma3.attention.value_length"] = p.HeadDim
-		kv["gemma3.attention.sliding_window"] = p.SlidingWindow
-		kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
-		kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
-		kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
-		kv["gemma3.embedding_length"] = p.HiddenSize
-		kv["gemma3.feed_forward_length"] = p.IntermediateSize
-	default:
-		kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
-		kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
-		kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
-		kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
-		kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
-		kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
-		kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
-		kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
-		kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
-		kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
-		kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
-		kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
-		kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
-		kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
-	}
-
-	if p.MultiModalTokensPerImage > 0 {
-		kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
-	}
-
-	return kv
-}
-
-func (p *gemma3Model) Replacements() []string {
-	return []string{
-		"lm_head", "output",
-		"model.embed_tokens", "token_embd",
-		"model.norm", "output_norm",
-		"vision_tower.vision_model.embeddings", "v",
-		"vision_tower.vision_model", "v",
-		"vision_model.vision_model.embeddings", "v",
-		"vision_model.vision_model", "v",
-		"language_model.", "",
-		"model.layers", "blk",
-		"encoder.layers", "blk",
-		"input_layernorm", "attn_norm",
-		"self_attn.q_proj", "attn_q",
-		"self_attn.q_norm", "attn_q_norm",
-		"self_attn.k_proj", "attn_k",
-		"self_attn.k_norm", "attn_k_norm",
-		"self_attn.v_proj", "attn_v",
-		"self_attn.o_proj", "attn_output",
-		"self_attn.out_proj", "attn_output",
-		"mlp.gate_proj", "ffn_gate",
-		"mlp.down_proj", "ffn_down",
-		"mlp.up_proj", "ffn_up",
-		"post_attention_layernorm", "post_attention_norm",
-		"pre_feedforward_layernorm", "ffn_norm",
-		"post_feedforward_layernorm", "post_ffw_norm",
-		"input_projection_weight", "input_projection.weight",
-		"multi_modal_projector", "mm",
-	}
-}


@@ -6,9 +6,7 @@ import (
 	"errors"
 	"fmt"
 	"io/fs"
-	"log/slog"
 	"os"
-	"reflect"
 	"slices"

 	"google.golang.org/protobuf/proto"
@@ -17,8 +15,6 @@ import (
 )

 func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
-	slog.Debug("using spm vocabulary")
-
 	ast, err := parseAdditionalSpecialTokens(fsys)
 	if err != nil {
 		return nil, err
@@ -47,19 +43,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 			v.Types = append(v.Types, int32(t))
 		default:
 			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
-			// temporary fix to handle gemma3 broken configs
-			if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
+			if slices.Contains(ast, piece.GetPiece()) {
 				tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
 			}
-			for _, t := range ast {
-				if t.Content == piece.GetPiece() {
-					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
-					break
-				}
-			}

 			v.Types = append(v.Types, tt)
 		}
@@ -91,16 +78,10 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 		return cmp.Compare(i.id, j.id)
 	})

-	for _, t := range ts {
-		if t.id < len(v.Tokens) {
-			if v.Tokens[t.id] == t.content {
-				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
-				continue
-			}
-			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
-		}
-		if t.id != len(v.Tokens) {
-			return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
+	n := len(v.Tokens)
+	for i, t := range ts {
+		if t.id != i+n {
+			return nil, fmt.Errorf("invalid token id: %d", t.id)
 		}

 		v.Tokens = append(v.Tokens, t.content)
@@ -111,15 +92,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 	return &v, nil
 }

-type specialToken struct {
-	Content    string `json:"content"`
-	Lstrip     bool   `json:"lstrip"`
-	Normalized bool   `json:"normalized"`
-	Rstrip     bool   `json:"rstrip"`
-	SingleWord bool   `json:"single_word"`
-}
-
-func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
+func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
 	f, err := fsys.Open("special_tokens_map.json")
 	if errors.Is(err, os.ErrNotExist) {
 		return nil, nil
@@ -129,43 +102,12 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
 	defer f.Close()

 	var m struct {
-		AdditionalSpecialTokens any `json:"additional_special_tokens"`
+		AdditionalSpecialTokens []string `json:"additional_special_tokens"`
 	}

 	if err := json.NewDecoder(f).Decode(&m); err != nil {
 		return nil, err
 	}

-	var ast []specialToken
-	switch st := m.AdditionalSpecialTokens.(type) {
-	case []string:
-		for _, s := range st {
-			ast = append(ast, specialToken{Content: s})
-		}
-	case []any:
-		for _, s := range st {
-			// marshal and unmarshal the object to get the special token
-			tMap := s.(map[string]any)
-			data, err := json.Marshal(tMap)
-			if err != nil {
-				return nil, err
-			}
-
-			var token specialToken
-			err = json.Unmarshal(data, &token)
-			if err != nil {
-				return nil, err
-			}
-
-			ast = append(ast, token)
-		}
-	default:
-		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
-	}
-
-	slog.Debug("spm tokenizer", "additional tokens", ast)
-
-	return ast, nil
+	return m.AdditionalSpecialTokens, nil
 }


@@ -118,35 +118,6 @@ To run tests, use `go test`:
 go test ./...
 ```

-> NOTE: In rare circumstances, you may need to change a package using the new
-> "synctest" package in go1.24.
->
-> If you do not have the "synctest" package enabled, you will not see build or
-> test failures resulting from your change(s), if any, locally, but CI will
-> break.
->
-> If you see failures in CI, you can either keep pushing changes to see if the
-> CI build passes, or you can enable the "synctest" package locally to see the
-> failures before pushing.
->
-> To enable the "synctest" package for testing, run the following command:
->
-> ```shell
-> GOEXPERIMENT=synctest go test ./...
-> ```
->
-> If you wish to enable synctest for all go commands, you can set the
-> `GOEXPERIMENT` environment variable in your shell profile or by using:
->
-> ```shell
-> go env -w GOEXPERIMENT=synctest
-> ```
->
-> This enables the "synctest" package for all go commands without needing
-> to set it for every shell session.
->
-> The synctest package is not required for production builds.

 ## Library detection

 Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:


@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).

 ## How can I specify the context window size?

-By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
+By default, Ollama uses a context window size of 2048 tokens.

 To change this when using `ollama run`, use `/set parameter`:
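For example, inside an interactive `ollama run` session (the parameter value here is illustrative):

```shell
/set parameter num_ctx 4096
```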


@@ -75,7 +75,7 @@ RestartSec=3
 Environment="PATH=$PATH"

 [Install]
-WantedBy=multi-user.target
+WantedBy=default.target
 ```

 Then start the service:
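With standard systemd tooling, that step typically looks like the following (shown as an illustration, not necessarily the exact commands from the docs):

```shell
sudo systemctl daemon-reload
sudo systemctl enable ollama
sudo systemctl start ollama
```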


@@ -81,11 +81,9 @@ help you keep up to date.

 If you'd like to install or integrate Ollama as a service, a standalone
 `ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
-and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
-and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
-same directory. This allows for embedding Ollama in existing applications, or
-running it as a system service via `ollama serve` with tools such as
-[NSSM](https://nssm.cc/).
+and GPU library dependencies for Nvidia and AMD. This allows for embedding
+Ollama in existing applications, or running it as a system service via `ollama
+serve` with tools such as [NSSM](https://nssm.cc/).

 > [!NOTE]
 > If you are upgrading from a prior version, you should remove the old directories first.
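A sketch of the NSSM route mentioned above (the service name and install path are hypothetical):

```shell
nssm install Ollama "C:\Program Files\Ollama\ollama.exe" serve
nssm start Ollama
```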


@@ -124,19 +124,6 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }

-func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
-	r := keyValue(kv, key, &array{})
-	s := make([]float32, r.size)
-	for i := range r.size {
-		s[i] = float32(r.values[i].(float32))
-	}
-	return s
-}
-
-func (kv KV) OllamaEngineRequired() bool {
-	return kv.Architecture() == "gemma3"
-}
-
 func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
@@ -327,10 +314,6 @@ func (t Tensor) Size() uint64 {
 	return t.parameters() * t.typeSize() / t.blockSize()
 }

-func (t Tensor) Type() string {
-	return fileType(t.Kind).String()
-}
-
 type container interface {
 	Name() string
 	Decode(io.ReadSeeker) (model, error)
@@ -493,7 +476,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 			// vocab graph
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
-	case "gemma", "gemma2", "gemma3":
+	case "gemma", "gemma2":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
 			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -582,43 +565,6 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 	return
 }

-func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
-	switch llm.KV().Architecture() {
-	case "mllama":
-		for _, layer := range llm.Tensors().GroupLayers()["v"] {
-			weights += layer.Size()
-		}
-
-		kv := func(n string) uint64 {
-			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
-				return uint64(v)
-			}
-
-			return 0
-		}
-
-		imageSize := kv("image_size")
-		maxNumTiles := kv("max_num_tiles")
-		embeddingLength := kv("embedding_length")
-		headCount := kv("attention.head_count")
-
-		numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
-		if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
-			numPatches++
-		}
-
-		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
-
-		graphSize = 4 * (8 +
-			imageSize*imageSize*kv("num_channels")*maxNumTiles +
-			embeddingLength*numPatches*maxNumTiles +
-			9*embeddingLength*numPaddedPatches*maxNumTiles +
-			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
-	}
-
-	return weights, graphSize
-}
-
 // SupportsKVCacheType checks if the requested cache type is supported
 func (f GGML) SupportsKVCacheType(cacheType string) bool {
 	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)

go.mod

@@ -24,7 +24,7 @@ require (
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
-	golang.org/x/tools v0.30.0
+	gonum.org/v1/gonum v0.15.0
 )

 require (
@@ -44,7 +44,6 @@ require (
 	github.com/xtgo/set v1.0.0 // indirect
 	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
 	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
-	gonum.org/v1/gonum v0.15.0 // indirect
 	gorgonia.org/vecf32 v0.9.0 // indirect
 	gorgonia.org/vecf64 v0.9.0 // indirect
 )

go.sum

@@ -309,8 +309,6 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
-golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
-golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=


@@ -4,7 +4,6 @@ import (
 	"errors"

 	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model/input"
 )

 var (
@@ -30,17 +29,6 @@ type Cache interface {
 	// cache implementation used.
 	Put(ctx ml.Context, key, value ml.Tensor)

-	// SetConfig controls optimizations (mostly backend-specific) that may transform
-	// the output of the cache to work better with specific kernels. If not called,
-	// the backend settings will be used. This works well when calling Attention.
-	//
-	// The config can be overridden by models, especially if they require vanilla
-	// output when implementing their own version of attention. To do this, pass
-	// an empty ml.CacheConfig.
-	//
-	// Most models will not need to use this.
-	SetConfig(ml.CacheConfig)
-
 	// ** cache management **

 	// Init sets up runtime parameters
@@ -52,7 +40,7 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
 	// entry in positions and seqs.
-	StartForward(ctx ml.Context, opts input.Options) error
+	StartForward(ctx ml.Context, positions []int32, seqs []int) error

 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)


@@ -8,7 +8,6 @@ import (
 	"slices"

 	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model/input"
 )

 type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -23,11 +22,6 @@ type Causal struct {
 	Capacity   int32
 	windowSize int32

-	opts CausalOptions
-
-	// config controls mostly backend-specific optimizations
-	config *ml.CacheConfig
-
 	// ** current forward pass **

 	// the active layer for Get and Put
@@ -45,12 +39,6 @@ type Causal struct {
 	// locations in the cache that are needed for this batch
 	curCellRange cellRange

-	// curSequences is the sequences corresponding to this pass's entries in the cache
-	curSequences []int
-
-	// curPositions is the positions corresponding to this pass's entries in the cache
-	curPositions []int32
-
 	// ** cache metadata **

 	// for each possible location in the cache, stores the position and set of sequences
@@ -64,8 +52,8 @@ type Causal struct {
 	shiftFn  shiftFn
 	backend  ml.Backend
-	ctxs     map[int]ml.Context
-	keys, values map[int]ml.Tensor
+	cacheCtx ml.Context
+	keys, values []ml.Tensor
 }

 type cacheCell struct {
@@ -79,72 +67,28 @@ type cellRange struct {
 }

 func NewCausalCache(shift shiftFn) *Causal {
-	return &Causal{
-		windowSize: math.MaxInt32,
-		shiftFn:    shift,
-		ctxs:       make(map[int]ml.Context),
-		keys:       make(map[int]ml.Tensor),
-		values:     make(map[int]ml.Tensor),
-	}
+	return &Causal{windowSize: math.MaxInt32, shiftFn: shift}
 }

 func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-	return &Causal{
-		windowSize: windowSize,
-		shiftFn:    shift,
-		ctxs:       make(map[int]ml.Context),
-		keys:       make(map[int]ml.Tensor),
-		values:     make(map[int]ml.Tensor),
-	}
+	return &Causal{windowSize: windowSize, shiftFn: shift}
 }

 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
-	if c.config == nil {
-		var config ml.CacheConfig
-		if cc, ok := backend.(ml.BackendCacheConfig); ok {
-			config = cc.CacheConfig()
-		}
-		c.config = &config
-	}
-
-	if c.config.CachePadding == 0 {
-		c.config.CachePadding = 1
-	}
-
-	if c.config.MaskBatchPadding == 0 {
-		c.config.MaskBatchPadding = 1
-	}
-
-	if c.config.MaskDType == ml.DTypeOther {
-		c.config.MaskDType = ml.DTypeF32
-	}
-
 	c.DType = dtype
-	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
-	c.cells = make([]cacheCell, c.Capacity)
+	c.Capacity = capacity
+	c.cells = make([]cacheCell, capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
-}
-
-func (c *Causal) SetConfig(config ml.CacheConfig) {
-	if c.config != nil {
-		panic("config cannot be changed after being previously set, either by the model or backend")
-	}
-
-	c.config = &config
+	c.cacheCtx = backend.NewContext()
 }

 func (c *Causal) Close() {
-	for _, ctx := range c.ctxs {
-		ctx.Close()
-	}
+	c.cacheCtx.Close()
 }

-func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
-	c.curBatchSize = len(opts.Positions)
-	c.curSequences = opts.Sequences
-	c.curPositions = opts.Positions
-
-	c.opts.Except = nil
+func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+	c.curBatchSize = len(positions)

 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -157,8 +101,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	}

 	c.curCellRange = newRange()
-	for i, pos := range opts.Positions {
-		seq := opts.Sequences[i]
+	for i, pos := range positions {
+		seq := seqs[i]

 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -183,7 +127,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 		c.cellRanges[seq] = seqRange
 	}

-	c.curMask, err = c.buildMask(ctx)
+	c.curMask, err = c.buildMask(ctx, positions, seqs)

 	return err
 }
@@ -213,91 +157,36 @@ func (c *Causal) findStartLoc() (int, error) {
 	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
 }

-func roundDown(length, pad int) int {
-	return (length / pad) * pad
-}
-
-func roundUp(length, pad int) int {
-	return ((length + pad - 1) / pad) * pad
-}
-
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
-func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
-	// Align and pad the two dimensions as required by the backend
-	batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
-
-	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
-	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
-
-	length := c.curCellRange.max - c.curCellRange.min + 1
-	mask := make([]float32, batchSize*length)
+func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
+	// TODO(jessegross): This does not do padding, which is required for flash attention
+	len := c.curCellRange.max - c.curCellRange.min + 1
+	mask := make([]float32, c.curBatchSize*len)

 	for i := range c.curBatchSize {
-		enabled := !slices.Contains(c.opts.Except, i)
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-				(enabled && c.cells[j].pos > c.curPositions[i]) ||
-				c.cells[j].pos < c.curPositions[i]-c.windowSize {
-				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
+				c.cells[j].pos < positions[i]-c.windowSize {
+				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
 	}

-	// Mask out any padding tokens we added. For padding that we added to the cache history, this
-	// has already been masked out because the sequence doesn't match.
-	for i := c.curBatchSize * length; i < len(mask); i++ {
-		mask[i] = float32(math.Inf(-1))
-	}
-
-	maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
-	if err != nil {
-		return nil, err
-	}
-
-	if c.config.MaskDType != ml.DTypeF32 {
-		out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
-		ctx.Forward(maskTensor.Copy(ctx, out))
-		maskTensor = out
-	}
-
-	return maskTensor, nil
+	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
 }

-func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
-	for i, key := range c.keys {
-		if key == nil {
+func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
+	for _, obj := range objs {
+		if obj == nil {
 			continue
 		}

-		kHeadDim := key.Dim(0)
-		numKVHeads := key.Dim(1)
-		rowSize := key.Stride(2)
-
-		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
-		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
-
-		value := c.values[i]
-		var vSrcView, vDstView ml.Tensor
-		if c.config.PermutedV {
-			vHeadDim := value.Dim(1)
-			elemSize := value.Stride(0)
-
-			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
-			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
-		} else {
-			vHeadDim := value.Dim(0)
-			rowSize := value.Stride(2)
-
-			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
-			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
-		}
-
-		ctx.Forward(
-			kSrcView.Copy(ctx, kDstView),
-			vSrcView.Copy(ctx, vDstView),
-		)
+		srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
+		dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
+
+		ctx.Forward(srcView.Copy(ctx, dstView))
 	}
 }
@@ -330,7 +219,7 @@ func (c *Causal) defrag() {
 		layers++
 	}

-	maxMoves := ctx.MaxGraphNodes() / (6 * layers)
+	maxMoves := ctx.MaxTensors() / (6 * layers)
 	moves := 0

 	var pendingSrc, pendingDst, pendingLen int
@@ -349,7 +238,8 @@ func (c *Causal) defrag() {
 					pendingLen++
 					break
 				} else {
-					c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
+					moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
+					moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
 					moves++
 				}
 			}
@@ -373,7 +263,8 @@ func (c *Causal) defrag() {
 	}

 	if pendingLen > 0 {
-		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
+		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
+		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
 		moves++
 	}
@@ -402,107 +293,47 @@ func (c *Causal) defrag() {
} }
func (c *Causal) SetLayer(layer int) { func (c *Causal) SetLayer(layer int) {
c.curLayer = layer if layer >= len(c.keys) {
} c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
type CausalOptions struct {
// Enabled controls whether the causal mask is generated for a particular index in a batch
Except []int
}
// SetCausal disables causal mask generation for a particular range of indicies in
// the current batch for subsequent calls to Get. The state resets for the next forward pass.
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts
if ctx != nil {
var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
}
} }
c.curLayer = layer
} }
 func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]

-	kHeadDim := key.Dim(0)
-	numKVHeads := key.Dim(1)
-	rowSize := key.Stride(2)
-	cachedSize := c.curMask.Dim(0)
-
-	key = key.View(ctx, rowSize*c.curCellRange.min,
-		kHeadDim, key.Stride(1),
-		numKVHeads, key.Stride(2),
-		cachedSize,
+	key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
+		key.Dim(0), key.Stride(1),
+		key.Dim(1), key.Stride(2),
+		c.curMask.Dim(0),
 	)

-	if c.config.PermutedV {
-		vHeadDim := value.Dim(1)
-		elemSize := value.Stride(0)
-
-		value = value.View(ctx, elemSize*c.curCellRange.min,
-			cachedSize, value.Stride(1),
-			vHeadDim, value.Stride(2),
-			numKVHeads,
-		)
-	} else {
-		vHeadDim := value.Dim(0)
-		rowSize := value.Stride(2)
-
-		value = value.View(ctx, rowSize*c.curCellRange.min,
-			vHeadDim, value.Stride(1),
-			numKVHeads, value.Stride(2),
-			cachedSize,
-		)
-	}
+	value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
+		value.Dim(0), value.Stride(1),
+		value.Dim(1), value.Stride(2),
+		c.curMask.Dim(0),
+	)

 	return key, value, c.curMask
 }
 func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-	kHeadDim := key.Dim(0)
-	vHeadDim := value.Dim(0)
-	numKVHeads := key.Dim(1)
-	batchSize := key.Dim(2)
-
-	if c.curBatchSize != batchSize {
-		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
+	if c.curBatchSize != key.Dim(2) {
+		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
 	}

-	if _, ok := c.ctxs[c.curLayer]; !ok {
-		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
-	}
-
-	if _, ok := c.keys[c.curLayer]; !ok {
-		c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
-	}
-
-	if _, ok := c.values[c.curLayer]; !ok {
-		if c.config.PermutedV {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
-		} else {
-			c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
-		}
-	}
-
-	rowSize := c.keys[c.curLayer].Stride(2)
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
-
-	if c.config.PermutedV {
-		elemSize := c.values[c.curLayer].Stride(0)
-
-		value = value.Permute(ctx, 1, 2, 0, 3)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
-	} else {
-		rowSize := c.values[c.curLayer].Stride(2)
-		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
-	}
+	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
+		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+	}
+
+	ctx.Forward(
+		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
+		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
+	)
 }
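Both Put variants address the cache through flat views: the offset is the row stride (one token's worth of elements) times the destination cell, and the window length is headDim x numHeads x batch tokens. A minimal sketch of that addressing over a plain slice; this is illustrative, not the real ml.Tensor API:

package main

import "fmt"

// A toy row-major cache of shape [headDim, numHeads, capacity] stored flat.
type kvSlab struct {
	data              []float32
	headDim, numHeads int
}

// view mimics the flat Tensor.View calls above: the offset is the row size
// (Stride(2) in element units) times the starting cell, and the window spans
// n cells of headDim*numHeads elements each.
func (s kvSlab) view(loc, n int) []float32 {
	rowSize := s.headDim * s.numHeads
	return s.data[rowSize*loc : rowSize*(loc+n)]
}

func main() {
	s := kvSlab{data: make([]float32, 4*2*8), headDim: 4, numHeads: 2}
	copy(s.view(3, 1), []float32{1, 2, 3, 4, 5, 6, 7, 8}) // Put one token at cell 3
	fmt.Println(s.view(3, 1))                             // Get the same window back
}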
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -548,7 +379,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 		}
 	}

-	kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
+	kShift, err := ctx.FromIntSlice(offsets, len(offsets))
 	if err != nil {
 		return err
 	}
@@ -558,13 +389,9 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 			continue
 		}

-		kHeadDim := key.Dim(0)
-		numKVHeads := key.Dim(1)
-		rowSize := key.Stride(2)
-
-		key = key.View(ctx, rowSize*seqRange.min,
-			kHeadDim, key.Stride(1),
-			numKVHeads, key.Stride(2),
+		key = key.View(ctx, key.Stride(2)*seqRange.min,
+			key.Dim(0), key.Stride(1),
+			key.Dim(1), key.Stride(2),
 			size,
 		)

View File

@@ -6,7 +6,6 @@ import (
 	"testing"

 	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model/input"
 )

 type testCase struct {
@@ -270,7 +269,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 		context := backend.NewContext()
 		defer context.Close()

-		err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
+		err := cache.StartForward(context, test.pos, test.seqs)
 		if err != nil {
 			panic(err)
 		}
@@ -304,17 +303,13 @@ func (b *testBackend) NewContext() ml.Context {
 	return &testContext{}
 }

-func (b *testBackend) NewContextSize(int) ml.Context {
-	return &testContext{}
-}
-
 func (b *testBackend) SystemInfo() string {
 	return "not implemented"
 }

 type testContext struct{}

-func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
+func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
 	total := 0

 	if len(shape) > 0 {
@@ -327,12 +322,8 @@ func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
 	return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
 }

-func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-	return c.Empty(dtype, shape...)
-}
-
 func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
-	t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
+	t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)

 	copy(t.data, s)
@@ -351,15 +342,11 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return out, nil
 }

-func (c *testContext) Input() ml.Context    { return c }
-func (c *testContext) Output() ml.Context   { return c }
-func (c *testContext) Layer(int) ml.Context { return c }
-
 func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

 func (c *testContext) Compute(...ml.Tensor) {}

-func (c *testContext) MaxGraphNodes() int {
+func (c *testContext) MaxTensors() int {
 	return 10
 }
@@ -404,7 +391,7 @@ func (t *testTensor) Floats() []float32 {
 }

 func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
+	out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)

 	for i := range out.data {
 		out.data[i] = t.data[i] + t2.(*testTensor).data[i]
@@ -441,19 +428,11 @@ func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
 	panic("not implemented")
 }

-func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
-	panic("not implemented")
-}
-
-func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
-	panic("not implemented")
-}
-
 func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	panic("not implemented")
 }

-func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
+func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor {
 	panic("not implemented")
 }
@@ -489,7 +468,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 	context := &testContext{}

-	view := context.Empty(t.dtype, s...).(*testTensor)
+	view := context.Zeros(t.dtype, s...).(*testTensor)
 	view.data = t.data[offset : offset+len(view.data)]

 	return view
@@ -503,10 +482,6 @@ func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
 	panic("not implemented")
 }

-func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
-	panic("not implemented")
-}
-
 func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
 	panic("not implemented")
 }

View File

@@ -1,10 +1,7 @@
 package kvcache

 import (
-	"fmt"
-
 	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model/input"
 )

 // Encoder cache stores K and V tensors that are position independent
@@ -14,9 +11,6 @@ import (
 //
 // Not currently safe for multiple sequences
 type EncoderCache struct {
-	// config controls mostly backend-specific optimizations
-	config *ml.CacheConfig
-
 	// ** current forward pass **

 	// the active layer for Get and Put
@@ -36,59 +30,36 @@ type EncoderCache struct {
 	encoderPos int32

 	// ** cache data storage **

-	backend      ml.Backend
-	ctxs         map[int]ml.Context
-	keys, values map[int]ml.Tensor
+	cacheCtx     ml.Context
+	keys, values []ml.Tensor
 }

 func NewEncoderCache() *EncoderCache {
-	return &EncoderCache{
-		ctxs:   make(map[int]ml.Context),
-		keys:   make(map[int]ml.Tensor),
-		values: make(map[int]ml.Tensor),
-	}
+	return &EncoderCache{}
 }

 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
-	if c.config == nil {
-		var config ml.CacheConfig
-		if cc, ok := backend.(ml.BackendCacheConfig); ok {
-			config = cc.CacheConfig()
-		}
-		c.config = &config
-	}
-
-	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
-		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
-	}
-
-	c.backend = backend
-}
-
-func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
-	if c.config != nil {
-		panic("config cannot be changed after being previously set, either by the model or backend")
-	}
-
-	c.config = &config
+	c.cacheCtx = backend.NewContext()
 }

 func (c *EncoderCache) Close() {
-	for _, ctx := range c.ctxs {
-		ctx.Close()
-	}
+	c.cacheCtx.Close()
 }

-func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
-	// We work with the most recent image
-	if len(opts.Multimodal) > 0 {
-		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
-	}
+func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+	// The image is always in the first position
+	c.curPos = positions[0]

 	return nil
 }

 func (c *EncoderCache) SetLayer(layer int) {
+	if layer >= len(c.keys) {
+		c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...)
+		c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...)
+	}
+
 	c.curLayer = layer
 }
@@ -104,20 +75,9 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 	c.encoderPos = c.curPos
 	c.encoderCached = true

-	if c.config.PermutedV {
-		value = value.Permute(ctx, 1, 2, 0, 3)
-	}
-
-	if _, ok := c.ctxs[c.curLayer]; !ok {
-		c.ctxs[c.curLayer] = c.backend.NewContextSize(2).Layer(c.curLayer)
-	}
-
-	if _, ok := c.keys[c.curLayer]; !ok {
-		c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
-	}
-
-	if _, ok := c.values[c.curLayer]; !ok {
-		c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
-	}
+	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
+		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
+	}

 	ctx.Forward(

View File

@@ -4,7 +4,6 @@ import (
 	"math"

 	"github.com/ollama/ollama/ml"
-	"github.com/ollama/ollama/model/input"
 )

 // Wrapper cache is a container for multiple types of caches,
@@ -29,26 +28,20 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 	}
 }

-func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
-	for _, cache := range c.caches {
-		cache.SetConfig(config)
-	}
-}
-
 func (c *WrapperCache) Close() {
 	for _, cache := range c.caches {
 		cache.Close()
 	}
 }

-func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, positions, seqs)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range opts.Positions {
-					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
+				for k := range positions {
+					_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
 				}
 			}
 			return err

View File

@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
     const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
     if (precompiled_charsmap_keyidx != -1) {
-        size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
+        size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
         const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
         precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
 #ifdef IS_BIG_ENDIAN

View File

@@ -21,6 +21,18 @@ package llama
 extern bool llamaProgressCallback(float progress, void *user_data);
 extern void llamaLog(int level, char* text, void* user_data);

+typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
+COMPILER inline get_compiler() {
+#if defined(__clang__)
+	return COMP_CLANG;
+#elif defined(__GNUC__)
+	return COMP_GCC;
+#else
+	return COMP_UNKNOWN;
+#endif
+}
+
 */
 import "C"
@@ -60,6 +72,19 @@ func BackendInit() {
 	C.llama_backend_init()
 }

+func PrintSystemInfo() string {
+	var compiler string
+	switch C.get_compiler() {
+	case C.COMP_UNKNOWN:
+		compiler = "cgo(unknown_compiler)"
+	case C.COMP_GCC:
+		compiler = "cgo(gcc)"
+	case C.COMP_CLANG:
+		compiler = "cgo(clang)"
+	}
+	return C.GoString(C.llama_print_system_info()) + compiler
+}
+
 func GetModelArch(modelPath string) (string, error) {
 	mp := C.CString(modelPath)
 	defer C.free(unsafe.Pointer(mp))
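The added cgo preamble decides the compiler at build time with the C preprocessor, so get_compiler reports whichever toolchain compiled the preamble; note the __clang__ test must come first since clang also defines __GNUC__. A self-contained sketch of the same pattern (a hypothetical standalone program, not part of the package):

package main

/*
typedef enum {COMP_UNKNOWN, COMP_GCC, COMP_CLANG} COMPILER;

static COMPILER get_compiler() {
#if defined(__clang__)
	return COMP_CLANG;
#elif defined(__GNUC__)
	return COMP_GCC;
#else
	return COMP_UNKNOWN;
#endif
}
*/
import "C"

import "fmt"

func main() {
	// The branch is chosen by the preprocessor when the cgo preamble is
	// compiled, not at run time.
	switch C.get_compiler() {
	case C.COMP_CLANG:
		fmt.Println("built with clang")
	case C.COMP_GCC:
		fmt.Println("built with gcc")
	default:
		fmt.Println("unknown compiler")
	}
}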
@@ -245,20 +270,6 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 	return &m, nil
 }

-func LoadVocabFromFile(path string) (*Vocab, error) {
-	mp := C.CString(path)
-	defer C.free(unsafe.Pointer(mp))
-	v := Vocab{c: C.llama_load_vocab_from_file(mp)}
-	if v.c == nil {
-		return nil, fmt.Errorf("unable to load vocab: %s", path)
-	}
-	return &v, nil
-}
-
-func FreeVocab(vocab *Vocab) {
-	C.llama_free_vocab(vocab.c)
-}
-
 func FreeModel(model *Model) {
 	C.llama_model_free(model.c)
 }
@@ -307,10 +318,6 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 	return nil
 }

-type Vocab struct {
-	c *C.struct_llama_vocab
-}
-
 func (m *Model) Vocab() *C.struct_llama_vocab {
 	return C.llama_model_get_vocab(m.c)
 }
@@ -687,53 +694,3 @@
 	}
 	return buf[:n]
 }
-
-type Sampler struct {
-	c *C.struct_llama_sampler
-}
-
-func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
-	cGrammar := C.CString(grammar)
-	cRoot := C.CString("root")
-	defer C.free(unsafe.Pointer(cGrammar))
-	defer C.free(unsafe.Pointer(cRoot))
-
-	sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
-
-	return sampler
-}
-
-func (s *Sampler) Accept(token int32) {
-	C.llama_sampler_accept(s.c, C.llama_token(token))
-}
-
-type TokenData struct {
-	Id    int32
-	Logit float32
-}
-
-func (s *Sampler) Apply(tokens []TokenData) {
-	tds := make([]C.struct_llama_token_data, len(tokens))
-	for i, token := range tokens {
-		tds[i] = C.struct_llama_token_data{
-			id:    C.int32_t(token.Id),
-			logit: C.float(token.Logit),
-			p:     C.float(0.0),
-		}
-	}
-	tda := &C.llama_token_data_array{
-		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
-		size:     C.size_t(len(tokens)),
-		selected: C.int64_t(-1),
-		sorted:   C.bool(false),
-	}
-
-	var pinner runtime.Pinner
-	pinner.Pin(&tds[0])
-	defer pinner.Unpin()
-
-	C.llama_sampler_apply(s.c, tda)
-	for i := range tokens {
-		tokens[i].Logit = float32(tds[i].logit)
-	}
-}
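The removed Sampler wraps llama.cpp's grammar sampler: Apply rescores candidate logits against the grammar (disallowed tokens come back with -Inf logits) and Accept advances the grammar state. A hypothetical caller of this removed API, sketched against the types above rather than taken from the repository:

// pickGrammarToken applies the grammar mask to candidate logits, picks the
// best surviving token, then advances the grammar state with Accept.
// s would come from NewGrammarSampler(vocab, grammar) above.
func pickGrammarToken(s *Sampler, logits []float32) int32 {
	tokens := make([]TokenData, len(logits))
	for i, l := range logits {
		tokens[i] = TokenData{Id: int32(i), Logit: l}
	}
	s.Apply(tokens) // tokens the grammar forbids drop to -Inf

	best := tokens[0]
	for _, t := range tokens[1:] {
		if t.Logit > best.Logit {
			best = t
		}
	}
	s.Accept(best.Id) // advance grammar state past the chosen token
	return best.Id
}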

View File

@@ -0,0 +1,69 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 11 Feb 2025 14:06:36 -0800
Subject: [PATCH] try/catch backend load
---
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
1 file changed, 23 insertions(+), 22 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 98d5e14d..1c19129a 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
- if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- if (!handle && !silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
- if (handle) {
+ try {
+ if (entry.is_regular_file()) {
+ std::wstring filename = entry.path().filename().wstring();
+ std::wstring ext = entry.path().extension().wstring();
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) {
- int s = score_fn();
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
- if (s > best_score) {
- best_score = s;
- best_path = entry.path().wstring();
- }
- } else {
- if (!silent) {
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path().wstring();
}
}
}
+ } catch (const std::exception & e) {
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
}
}
}

View File

@@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
 Subject: [PATCH] use std::filesystem::path instead of wstring

 ---
- ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
- 1 file changed, 88 insertions(+), 111 deletions(-)
+ ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------
+ 1 file changed, 58 insertions(+), 86 deletions(-)

 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 98d5e14d..799af5f3 100644
+index 1c19129a..c854e6bb 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -66,26 +66,6 @@
@@ -264,55 +264,47 @@ index 98d5e14d..799af5f3 100644
for (const auto & search_path : search_paths) { for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) { if (!fs::exists(search_path)) {
continue; continue;
@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, @@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) { for (const auto & entry : dir_it) {
if (entry.is_regular_file()) { try {
- std::wstring filename = entry.path().filename().wstring(); if (entry.is_regular_file()) {
- std::wstring ext = entry.path().extension().wstring(); - std::wstring filename = entry.path().filename().wstring();
+ std::string filename = entry.path().filename().string(); - std::wstring ext = entry.path().extension().wstring();
+ std::string ext = entry.path().extension().string(); + std::string filename = entry.path().filename().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { + std::string ext = entry.path().extension().string();
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- if (!handle && !silent) { - dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + dl_handle_ptr handle { dl_load_library(entry.path()) };
+ dl_handle_ptr handle { dl_load_library(entry.path()) }; if (!handle) {
+ if (!handle) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue; continue;
} }
- if (handle) {
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) { if (!score_fn) {
- int s = score_fn(); - GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-#ifndef NDEBUG + GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); continue;
-#endif }
- if (s > best_score) {
- best_score = s; int s = score_fn();
- best_path = entry.path().wstring(); - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
- } + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
- } else { if (s > best_score) {
- if (!silent) { best_score = s;
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); - best_path = entry.path().wstring();
- } + best_path = entry.path();
- } }
+
+ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path();
} }
} }
} catch (const std::exception & e) {
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
} }
@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, }
}
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) { if (best_score == 0) {
// try to load the base backend // try to load the base backend
for (const auto & search_path : search_paths) { for (const auto & search_path : search_paths) {
@@ -321,49 +313,3 @@ index 98d5e14d..799af5f3 100644
if (fs::exists(path)) { if (fs::exists(path)) {
return get_reg().load_backend(path, silent); return get_reg().load_backend(path, silent);
} }
@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
ggml_backend_load_all_from_path(nullptr);
}
+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
+ try {
+ ggml_backend_load_best(name, silent, user_search_path);
+ } catch (const std::exception & e) {
+ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
+ }
+}
+
void ggml_backend_load_all_from_path(const char * dir_path) {
#ifdef NDEBUG
bool silent = true;
@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
bool silent = false;
#endif
- ggml_backend_load_best("blas", silent, dir_path);
- ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
- ggml_backend_load_best("kompute", silent, dir_path);
- ggml_backend_load_best("metal", silent, dir_path);
- ggml_backend_load_best("rpc", silent, dir_path);
- ggml_backend_load_best("sycl", silent, dir_path);
- ggml_backend_load_best("vulkan", silent, dir_path);
- ggml_backend_load_best("opencl", silent, dir_path);
- ggml_backend_load_best("musa", silent, dir_path);
- ggml_backend_load_best("cpu", silent, dir_path);
+ ggml_backend_try_load_best("blas", silent, dir_path);
+ ggml_backend_try_load_best("cann", silent, dir_path);
+ ggml_backend_try_load_best("cuda", silent, dir_path);
+ ggml_backend_try_load_best("hip", silent, dir_path);
+ ggml_backend_try_load_best("kompute", silent, dir_path);
+ ggml_backend_try_load_best("metal", silent, dir_path);
+ ggml_backend_try_load_best("rpc", silent, dir_path);
+ ggml_backend_try_load_best("sycl", silent, dir_path);
+ ggml_backend_try_load_best("vulkan", silent, dir_path);
+ ggml_backend_try_load_best("opencl", silent, dir_path);
+ ggml_backend_try_load_best("musa", silent, dir_path);
+ ggml_backend_try_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
if (backend_path) {

View File

@@ -1,64 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Wed, 5 Mar 2025 17:41:07 -0800
Subject: [PATCH] fix string arr kv loading
---
ggml/include/gguf.h | 1 +
ggml/src/gguf.cpp | 7 +++++--
src/llama-vocab.cpp | 2 +-
3 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
index 79ee2020..3efb22f0 100644
--- a/ggml/include/gguf.h
+++ b/ggml/include/gguf.h
@@ -114,6 +114,7 @@ extern "C" {
// get raw pointer to the first element of the array with the given key_id
// for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+ GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);
// get ith C string from array with given key_id
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
index ab13669c..f75b923f 100644
--- a/ggml/src/gguf.cpp
+++ b/ggml/src/gguf.cpp
@@ -777,10 +777,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id
const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
+size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
+ GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+ return ctx->kv[key_id].data.size();
+}
+
const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -874,7 +878,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
- GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
return ctx->kv[key_id].data.data();
}
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index c7ff28be..7a185443 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1443,7 +1443,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
if (precompiled_charsmap_keyidx != -1) {
- size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+ size_t n_precompiled_charsmap = gguf_get_arr_data_n(ctx, precompiled_charsmap_keyidx);
const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
#ifdef IS_BIG_ENDIAN

View File

@@ -1,33 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Sun, 9 Mar 2025 14:44:16 -0700
Subject: [PATCH] ollama debug tensor
---
ggml/src/ggml-cpu/ggml-cpu.c | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 2f606d82..ec60e8fc 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -11,6 +11,8 @@
#include "ggml-threading.h"
#include "ggml.h"
+#include "ollama-debug.h"
+
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14103,6 +14105,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
ggml_compute_forward(&params, node);
+#ifdef OLLAMA_DEBUG
+ ollama_debug(node, true);
+#endif
+
if (state->ith == 0 && cplan->abort_callback &&
cplan->abort_callback(cplan->abort_callback_data)) {
atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

View File

@@ -2,9 +2,6 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 #include "json-schema-to-grammar.h"
-#include "llama.h"
-#include "llama-model.h"
-#include "llama-model-loader.h"

 struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params) {
     try {
@@ -67,22 +64,3 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
         return 0;
     }
 }
-
-struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
-    llama_vocab * vocab = new llama_vocab();
-    try {
-        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
-        std::vector<std::string> splits = {};
-        llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
-        vocab->load(ml, kv);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
-        return nullptr;
-    }
-    return vocab;
-}
-
-void llama_free_vocab(struct llama_vocab * vocab) {
-    delete vocab;
-}

View File

@@ -35,9 +35,6 @@ extern "C"
     int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);

-    struct llama_vocab * llama_load_vocab_from_file(const char * fname);
-    void llama_free_vocab(struct llama_vocab * vocab);
-
 #ifdef __cplusplus
 }
 #endif

View File

@@ -115,9 +115,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
 	}

-	if projectorWeights == 0 && projectorGraph == 0 {
-		projectorWeights, projectorGraph = f.VisionGraphSize()
-	}
-
 	layers := f.Tensors().GroupLayers()
 	// add one layer worth of memory as a buffer

View File

@@ -30,7 +30,6 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llama"
-	"github.com/ollama/ollama/model"
 )

 type LlamaServer interface {
@@ -55,15 +54,8 @@
 	options     api.Options
 	numParallel int
 	modelPath   string
-
-	// llamaModel is an instance of the cgo llama.cpp model definition
-	// nil if this server is running the new engine
-	llamaModel     *llama.Model
-	llamaModelLock sync.Mutex
-
-	// textProcessor handles text encoding/decoding for the model in the Ollama engine
-	// nil if this server is running the llama.cpp based engine
-	textProcessor model.TextProcessor
+	modelLock   sync.Mutex   // Temporary until we switch fully to Go server
+	model       *llama.Model // If non-nil, the runner is a new Go server

 	estimate    MemoryEstimate
 	totalLayers uint64
@@ -97,7 +89,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {

 // NewLlamaServer will run a server for the given GPUs
 // The gpu list must be a single family.
-func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
+func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
 	systemInfo := discover.GetSystemInfo()
 	systemTotalMemory := systemInfo.System.TotalMemory
 	systemFreeMemory := systemInfo.System.FreeMemory
@@ -138,7 +130,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 	slog.Info("offload", "", estimate)

 	params := []string{
-		"--model", modelPath,
+		"--model", model,
 		"--ctx-size", strconv.Itoa(opts.NumCtx),
 		"--batch-size", strconv.Itoa(opts.NumBatch),
 	}
@@ -161,6 +153,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		}
 	}

+	if len(projectors) > 0 {
+		// TODO: applying multiple projectors is not supported by the llama.cpp server yet
+		params = append(params, "--mmproj", projectors[0])
+	}
+
 	defaultThreads := systemInfo.GetOptimalThreadCount()
 	if opts.NumThread > 0 {
 		params = append(params, "--threads", strconv.Itoa(opts.NumThread))
@@ -260,34 +257,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		}
 	}
 	slog.Debug("compatible gpu libraries", "compatible", compatible)

-	exe, err := os.Executable()
-	if err != nil {
-		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
-	}
-
-	if eval, err := filepath.EvalSymlinks(exe); err == nil {
-		exe = eval
-	}
-
-	var llamaModel *llama.Model
-	var textProcessor model.TextProcessor
-	if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
-		textProcessor, err = model.NewTextProcessor(modelPath)
-		if err != nil {
-			// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
-			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
-		}
-	}
-	if textProcessor == nil {
-		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
-		if err != nil {
-			return nil, err
-		}
-	}
-
-	if len(projectors) > 0 && llamaModel != nil {
-		params = append(params, "--mmproj", projectors[0])
-	}
-
 	// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
 	// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
@@ -306,9 +275,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 			port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 		}
 		finalParams := []string{"runner"}
-		if textProcessor != nil {
-			// New engine
-			// TODO - if we have failure to load scenarios, add logic to retry with the old runner
+		if envconfig.NewEngine() {
 			finalParams = append(finalParams, "--ollama-engine")
 		}
 		finalParams = append(finalParams, params...)
@@ -348,20 +315,28 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 		// finally, add the root library path
 		libraryPaths = append(libraryPaths, discover.LibOllamaPath)

+		exe, err := os.Executable()
+		if err != nil {
+			return nil, fmt.Errorf("unable to lookup executable path: %w", err)
+		}
+
+		if eval, err := filepath.EvalSymlinks(exe); err == nil {
+			exe = eval
+		}
+
+		// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
 		s := &llmServer{
-			port:          port,
-			cmd:           exec.Command(exe, finalParams...),
-			status:        NewStatusWriter(os.Stderr),
-			options:       opts,
-			modelPath:     modelPath,
-			llamaModel:    llamaModel,
-			textProcessor: textProcessor,
-			estimate:      estimate,
-			numParallel:   numParallel,
-			sem:           semaphore.NewWeighted(int64(numParallel)),
-			totalLayers:   f.KV().BlockCount() + 1,
-			gpus:          gpus,
-			done:          make(chan error, 1),
+			port:        port,
+			cmd:         exec.Command(exe, finalParams...),
+			status:      NewStatusWriter(os.Stderr),
+			options:     opts,
+			modelPath:   model,
+			estimate:    estimate,
+			numParallel: numParallel,
+			sem:         semaphore.NewWeighted(int64(numParallel)),
+			totalLayers: f.KV().BlockCount() + 1,
+			gpus:        gpus,
+			done:        make(chan error, 1),
 		}

 		s.cmd.Env = os.Environ()
@@ -430,9 +405,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
 			}
 			err := fmt.Errorf("error starting runner: %v %s", err, msg)
 			if len(compatible) == 0 {
-				if llamaModel != nil {
-					llama.FreeModel(llamaModel)
-				}
 				return nil, err
 			}
@@ -961,25 +933,64 @@ type TokenizeResponse struct {
 }

 func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
-	s.llamaModelLock.Lock()
-	defer s.llamaModelLock.Unlock()
-
-	if s.llamaModel != nil {
-		return s.llamaModel.Tokenize(content, false, true)
-	}
-	if s.textProcessor != nil {
-		tokens, err := s.textProcessor.Encode(content, false)
-		if err != nil {
-			return nil, err
-		}
-		toks := make([]int, len(tokens))
-		for i, t := range tokens {
-			toks[i] = int(t)
-		}
-		return toks, nil
-	}
-	// not reached
-	return nil, fmt.Errorf("no tokenizer configured")
+	s.modelLock.Lock()
+	defer s.modelLock.Unlock()
+	if s.model != nil {
+		return s.model.Tokenize(content, false, true)
+	}
+
+	// Make sure the server is ready
+	status, err := s.getServerStatus(ctx)
+	if err != nil {
+		return nil, err
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
+	}
+
+	data, err := json.Marshal(TokenizeRequest{Content: content})
+	if err != nil {
+		return nil, fmt.Errorf("marshaling encode data: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return nil, fmt.Errorf("encode request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("do encode request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusNotFound {
+		if s.model == nil {
+			slog.Debug("new runner detected, loading model for cgo tokenization")
+			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			if err != nil {
+				return nil, err
+			}
+			s.model = m
+		}
+		return s.model.Tokenize(content, false, true)
+	}
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("read encode request: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm encode error: %s", body)
+		return nil, fmt.Errorf("%s", body)
+	}
+
+	var encoded TokenizeResponse
+	if err := json.Unmarshal(body, &encoded); err != nil {
+		return nil, fmt.Errorf("unmarshal encode response: %w", err)
+	}
+
+	return encoded.Tokens, nil
 }
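The plus side tokenizes over HTTP against the runner's /tokenize endpoint and treats a 404 as the signal that the runner does not expose it, falling back to in-process cgo tokenization. A minimal standalone sketch of that probe-and-fallback shape; tokenizeViaHTTP and the JSON field names are illustrative, not the package's API:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// tokenizeViaHTTP posts to a runner's /tokenize endpoint and treats 404 as
// "endpoint absent": the caller then tokenizes in-process instead.
func tokenizeViaHTTP(url, content string, fallback func(string) []int) ([]int, error) {
	data, err := json.Marshal(map[string]string{"content": content})
	if err != nil {
		return nil, err
	}
	resp, err := http.Post(url, "application/json", bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode == http.StatusNotFound {
		return fallback(content), nil // runner has no /tokenize: tokenize locally
	}
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("tokenize failed: %s", resp.Status)
	}

	var out struct {
		Tokens []int `json:"tokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	return out.Tokens, nil
}

func main() {
	srv := httptest.NewServer(http.NotFoundHandler()) // simulates a runner without /tokenize
	defer srv.Close()

	toks, err := tokenizeViaHTTP(srv.URL+"/tokenize", "hello", func(string) []int {
		return []int{42} // stand-in for the cgo tokenizer
	})
	fmt.Println(toks, err) // [42] <nil>
}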
 type DetokenizeRequest struct {
@@ -991,38 +1002,80 @@ type DetokenizeResponse struct {
 }

 func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	s.llamaModelLock.Lock()
-	defer s.llamaModelLock.Unlock()
-
-	if s.llamaModel != nil {
+	s.modelLock.Lock()
+	defer s.modelLock.Unlock()
+	if s.model != nil {
 		var resp string
 		for _, token := range tokens {
-			resp += s.llamaModel.TokenToPiece(token)
+			resp += s.model.TokenToPiece(token)
 		}
 		return resp, nil
 	}
-	if s.textProcessor != nil {
-		toks := make([]int32, len(tokens))
-		for i, t := range tokens {
-			toks[i] = int32(t)
-		}
-		content, err := s.textProcessor.Decode(toks)
-		if err != nil {
-			return "", err
-		}
-		return content, nil
+
+	// Make sure the server is ready
+	status, err := s.getServerStatus(ctx)
+	if err != nil {
+		return "", err
+	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
+		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
-	// not reached
-	return "", fmt.Errorf("no tokenizer configured")
+
+	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
+	if err != nil {
+		return "", fmt.Errorf("marshaling decode data: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
+	if err != nil {
+		return "", fmt.Errorf("decode request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return "", fmt.Errorf("do decode request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusNotFound {
+		if s.model == nil {
+			slog.Debug("new runner detected, loading model for cgo tokenization")
+			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
+			if err != nil {
+				return "", err
+			}
+			s.model = m
+		}
+		var resp string
+		for _, token := range tokens {
+			resp += s.model.TokenToPiece(token)
+		}
+		return resp, nil
+	}
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", fmt.Errorf("read decode request: %w", err)
+	}
+
+	if resp.StatusCode >= 400 {
+		log.Printf("llm decode error: %s", body)
+		return "", fmt.Errorf("%s", body)
+	}
+
+	var decoded DetokenizeResponse
+	if err := json.Unmarshal(body, &decoded); err != nil {
+		return "", fmt.Errorf("unmarshal encode response: %w", err)
+	}
+
+	return decoded.Content, nil
 }

 func (s *llmServer) Close() error {
-	s.llamaModelLock.Lock()
-	if s.llamaModel != nil {
-		llama.FreeModel(s.llamaModel)
-		s.llamaModel = nil
+	s.modelLock.Lock()
+	if s.model != nil {
+		llama.FreeModel(s.model)
+		s.model = nil
 	}
-	s.llamaModelLock.Unlock()
+	s.modelLock.Unlock()

 	if s.cmd != nil {
 		slog.Debug("stopping llama server")
View File

@@ -1,40 +0,0 @@
package logging
import (
"context"
"log/slog"
"os"
)
const LevelTrace slog.Level = slog.LevelDebug - 4
type Logger struct {
logger *slog.Logger
}
func NewLogger() *Logger {
handler := slog.NewTextHandler(os.Stdout, nil)
return &Logger{
logger: slog.New(handler),
}
}
func (l *Logger) Trace(msg string, args ...any) {
l.logger.Log(context.Background(), LevelTrace, msg, args...)
}
func (l *Logger) Debug(msg string, args ...any) {
l.logger.Debug(msg, args...)
}
func (l *Logger) Info(msg string, args ...any) {
l.logger.Info(msg, args...)
}
func (l *Logger) Warn(msg string, args ...any) {
l.logger.Warn(msg, args...)
}
func (l *Logger) Error(msg string, args ...any) {
l.logger.Error(msg, args...)
}

View File

@@ -5,7 +5,6 @@ import (
 	"encoding/binary"
 	"fmt"
 	"os"
-	"slices"
 	"strconv"
 	"strings"
 )
@@ -19,43 +18,13 @@ type Config interface {
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32
-	Floats(string, ...[]float32) []float32
 }

 type Backend interface {
 	Config() Config
 	Get(name string) Tensor
 	NewContext() Context
-	NewContextSize(size int) Context
-}
-
-// BackendCacheConfig should be implemented by backends that need special output
-// from the cache to meet specific requirements. It is frequently implemented in
-// conjunction with ScaledDotProductAttention.
-type BackendCacheConfig interface {
-	CacheConfig() CacheConfig
-}
-
-// CacheConfig controls optimizations (mostly backend-specific) that may transform
-// the output the cache to work better with specific kernels.
-type CacheConfig struct {
-	// CachePadding specifies the multiple for the number of tokens of cache history
-	// that will be returned from cache Get for k, v and mask. The capacity of the
-	// cache itself will also be increased to a multiple of this size if needed.
-	CachePadding int
-
-	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
-	// and return the permuted version via Get. This uses the cache copy operation
-	// to avoid a Contiguous call on the permuted tensor.
-	PermutedV bool
-
-	// MaskDType specifies the data type for generating the mask. If unset it will
-	// default to DTypeF32.
-	MaskDType DType
-
-	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
-	// Any position that does not correspond to an actual token will be filled with -Inf.
-	MaskBatchPadding int
+	SystemInfo() string
 }
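The removed CacheConfig plumbing lets a backend ask caches to pad returned history, store V pre-permuted, or change the mask dtype before its kernels see them. A sketch of how a backend might have advertised such requirements; fakeBackend is hypothetical and the field values are illustrative:

// fakeBackend shows a backend satisfying the BackendCacheConfig interface
// removed above so that caches are laid out for its attention kernels.
type fakeBackend struct{}

func (fakeBackend) CacheConfig() ml.CacheConfig {
	return ml.CacheConfig{
		CachePadding:     32,          // Get returns history in multiples of 32 tokens
		PermutedV:        true,        // store V permuted to avoid a Contiguous call
		MaskDType:        ml.DTypeF16, // generate the mask in half precision
		MaskBatchPadding: 32,          // pad the mask's batch dimension with -Inf rows
	}
}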
 // BackendParams controls how the backend loads and executes models
@@ -71,9 +40,6 @@ type BackendParams struct {
 	// TensorSplit is the fraction of the model to offload to each GPU
 	TensorSplit []float32
-
-	// FlashAttention indicates that we should use a fused flash attention kernel
-	FlashAttention bool
 }

 var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
@@ -95,24 +61,14 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
 }

 type Context interface {
-	Empty(dtype DType, shape ...int) Tensor
 	Zeros(dtype DType, shape ...int) Tensor
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
 	FromIntSlice(s []int32, shape ...int) (Tensor, error)

 	Forward(...Tensor) Context
 	Compute(...Tensor)
-	MaxGraphNodes() int
+	MaxTensors() int
 	Close()
-
-	// Input returns a context appropriate for creating input tensors
-	Input() Context
-
-	// Output returns a context appropriate for creating output tensors
-	Output() Context
-
-	// Layer returns a context appropriate for creating intermediate tensors
-	Layer(int) Context
 }

 type Tensor interface {
@@ -135,10 +91,8 @@ type Tensor interface {
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor

-	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

-	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
+	RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim uint32, base, scale float32) Tensor

 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
@@ -148,7 +102,6 @@ type Tensor interface {
 	View(ctx Context, offset int, shape ...int) Tensor
 	Permute(ctx Context, shape ...int) Tensor
 	Contiguous(ctx Context) Tensor
-	Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor

 	Pad(ctx Context, shape ...int) Tensor
 	Unpad(ctx Context, shape ...int) Tensor
@@ -163,10 +116,6 @@ type Tensor interface {
 // operation equivalent to following code on a tensor named
 // query:
 //
-// query = query.Permute(ctx, 0, 2, 1, 3)
-// key = key.Permute(ctx, 0, 2, 1, 3)
-// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-//
 // kq := key.MulmatFullPrec(ctx, query)
 //
 // kq = kq.Scale(ctx, scale)
@@ -220,8 +169,8 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
 		return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
 			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
 		})
-	case DTypeF16, DTypeQ80, DTypeQ40:
-		f32 := ctx.Empty(DTypeF32, t.Shape()...)
+	case DTypeF16:
+		f32 := ctx.Zeros(DTypeF32, t.Shape()...)
 		f32 = t.Copy(ctx, f32)
 		return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
 			return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
@@ -246,17 +195,16 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
 	}

 	shape := t.Shape()
-	slices.Reverse(shape)

 	var sb strings.Builder
 	var f func([]int, int)
 	f = func(dims []int, stride int) {
 		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
-		sb.WriteString("[")
-		defer func() { sb.WriteString("]") }()
+		fmt.Fprint(&sb, "[")
+		defer func() { fmt.Fprint(&sb, "]") }()
 		for i := 0; i < dims[0]; i++ {
 			if i >= items && i < dims[0]-items {
-				sb.WriteString("..., ")
+				fmt.Fprint(&sb, "..., ")
 				// skip to next printable element
 				skip := dims[0] - 2*items
 				if len(dims) > 1 {
@@ -271,14 +219,9 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
 				fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
 			}
 		} else {
-			text := fn(s[stride+i])
-			if len(text) > 0 && text[0] != '-' {
-				sb.WriteString(" ")
-			}
-			sb.WriteString(text)
+			fmt.Fprint(&sb, fn(s[stride+i]))
 			if i < dims[0]-1 {
-				sb.WriteString(", ")
+				fmt.Fprint(&sb, ", ")
 			}
 		}
 	}
@@ -294,7 +237,5 @@ const (
 	DTypeOther DType = iota
 	DTypeF32
 	DTypeF16
-	DTypeQ80
-	DTypeQ40
 	DTypeI32
 )
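Dump pretty-prints a tensor, showing at most Items leading and trailing elements per dimension at the given Precision. A hypothetical debugging call against this API; obtaining ctx from a concrete backend is elided:

// debugTensor prints a small 2x2 tensor. Items: 3 prints everything here;
// Precision: 2 fixes the number of decimals.
func debugTensor(ctx ml.Context) {
	t, err := ctx.FromFloatSlice([]float32{1.5, -2.25, 3, 0.5}, 2, 2)
	if err != nil {
		panic(err)
	}
	fmt.Println(ml.Dump(ctx, t, ml.DumpOptions{Items: 3, Precision: 2}))
}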

File diff suppressed because it is too large

View File

@@ -114,7 +114,6 @@ extern "C" {
     // get raw pointer to the first element of the array with the given key_id
     // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
     GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
-    GGML_API size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id);

     // get ith C string from array with given key_id
     GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);

View File

@@ -1,11 +0,0 @@
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
void ollama_debug(const struct ggml_tensor *tensor, bool verbose);
#ifdef __cplusplus
}
#endif

View File

@@ -484,29 +484,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
         for (const auto & entry : dir_it) {
-            if (entry.is_regular_file()) {
-                std::string filename = entry.path().filename().string();
-                std::string ext = entry.path().extension().string();
-                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-                    dl_handle_ptr handle { dl_load_library(entry.path()) };
-                    if (!handle) {
-                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
-                        continue;
-                    }
-                    auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-                    if (!score_fn) {
-                        GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
-                        continue;
-                    }
-                    int s = score_fn();
-                    GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
-                    if (s > best_score) {
-                        best_score = s;
-                        best_path = entry.path();
+            try {
+                if (entry.is_regular_file()) {
+                    std::string filename = entry.path().filename().string();
+                    std::string ext = entry.path().extension().string();
+                    if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+                        dl_handle_ptr handle { dl_load_library(entry.path()) };
+                        if (!handle) {
+                            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
+                            continue;
+                        }
+                        auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+                        if (!score_fn) {
+                            GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
+                            continue;
+                        }
+                        int s = score_fn();
+                        GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
+                        if (s > best_score) {
+                            best_score = s;
+                            best_path = entry.path();
+                        }
                     }
                 }
+            } catch (const std::exception & e) {
+                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
             }
         }
     }
@@ -529,14 +533,6 @@ void ggml_backend_load_all() {
     ggml_backend_load_all_from_path(nullptr);
 }

-static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
-    try {
-        ggml_backend_load_best(name, silent, user_search_path);
-    } catch (const std::exception & e) {
-        GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
-    }
-}
-
 void ggml_backend_load_all_from_path(const char * dir_path) {
 #ifdef NDEBUG
     bool silent = true;
@@ -544,18 +540,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     bool silent = false;
 #endif

-    ggml_backend_try_load_best("blas", silent, dir_path);
-    ggml_backend_try_load_best("cann", silent, dir_path);
-    ggml_backend_try_load_best("cuda", silent, dir_path);
-    ggml_backend_try_load_best("hip", silent, dir_path);
-    ggml_backend_try_load_best("kompute", silent, dir_path);
-    ggml_backend_try_load_best("metal", silent, dir_path);
-    ggml_backend_try_load_best("rpc", silent, dir_path);
-    ggml_backend_try_load_best("sycl", silent, dir_path);
-    ggml_backend_try_load_best("vulkan", silent, dir_path);
-    ggml_backend_try_load_best("opencl", silent, dir_path);
-    ggml_backend_try_load_best("musa", silent, dir_path);
-    ggml_backend_try_load_best("cpu", silent, dir_path);
+    ggml_backend_load_best("blas", silent, dir_path);
+    ggml_backend_load_best("cann", silent, dir_path);
+    ggml_backend_load_best("cuda", silent, dir_path);
+    ggml_backend_load_best("hip", silent, dir_path);
+    ggml_backend_load_best("kompute", silent, dir_path);
+    ggml_backend_load_best("metal", silent, dir_path);
+    ggml_backend_load_best("rpc", silent, dir_path);
+    ggml_backend_load_best("sycl", silent, dir_path);
+    ggml_backend_load_best("vulkan", silent, dir_path);
+    ggml_backend_load_best("opencl", silent, dir_path);
+    ggml_backend_load_best("musa", silent, dir_path);
+    ggml_backend_load_best("cpu", silent, dir_path);

     // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
     const char * backend_path = std::getenv("GGML_BACKEND_PATH");
     if (backend_path) {

View File

@@ -1,6 +0,0 @@
//go:build debug
package cpu
// #cgo CPPFLAGS: -DOLLAMA_DEBUG
import "C"

View File

@@ -11,8 +11,6 @@
 #include "ggml-threading.h"
 #include "ggml.h"
-#include "ollama-debug.h"

 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -14105,10 +14103,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         ggml_compute_forward(&params, node);

-#ifdef OLLAMA_DEBUG
-        ollama_debug(node, true);
-#endif
-
         if (state->ith == 0 && cplan->abort_callback &&
             cplan->abort_callback(cplan->abort_callback_data)) {
             atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);

View File

@@ -7,20 +7,6 @@ package ggml
 // #include <stdlib.h>
 // #include "ggml-backend.h"
 // extern void sink(int level, char *text, void *user_data);
-// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); }
-// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; }
-/*
-typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER;
-static COMPILER compiler_name(void) {
-#if defined(__clang__)
-	return COMPILER_CLANG;
-#elif defined(__GNUC__)
-	return COMPILER_GNUC;
-#else
-	return COMPILER_UNKNOWN;
-#endif
-}
-*/
 import "C"

 import (
@@ -30,7 +16,6 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
-	"strconv"
 	"strings"
 	"sync"
 	"unsafe"
@@ -105,43 +90,4 @@ var OnceLoad = sync.OnceFunc(func() {
 			visited[abspath] = struct{}{}
 		}
 	}
-
-	slog.Info("system", "", system{})
 })
-
-type system struct{}
-
-func (system) LogValue() slog.Value {
-	var attrs []slog.Attr
-	names := make(map[string]int)
-
-	for i := range C.ggml_backend_dev_count() {
-		r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i))
-
-		func() {
-			fName := C.CString("ggml_backend_get_features")
-			defer C.free(unsafe.Pointer(fName))
-
-			if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil {
-				var features []any
-				for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) {
-					features = append(features, C.GoString(f.name), C.GoString(f.value))
-				}
-
-				name := C.GoString(C.ggml_backend_reg_name(r))
-				attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...))
-				names[name] += 1
-			}
-		}()
-	}
-
-	switch C.compiler_name() {
-	case C.COMPILER_CLANG:
-		attrs = append(attrs, slog.String("compiler", "cgo(clang)"))
-	case C.COMPILER_GNUC:
-		attrs = append(attrs, slog.String("compiler", "cgo(gcc)"))
-	default:
-		attrs = append(attrs, slog.String("compiler", "cgo(unknown)"))
-	}
-
-	return slog.GroupValue(attrs...)
-}

View File

@@ -777,14 +777,10 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id

 const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
     return ctx->kv[key_id].data.data();
 }

-size_t gguf_get_arr_data_n(const struct gguf_context * ctx, int64_t key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].data.size();
-}
-
 const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
@@ -878,6 +874,7 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
 const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
     GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
     GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
     return ctx->kv[key_id].data.data();
 }

View File

@@ -1,116 +0,0 @@
#include <string.h>
#include <inttypes.h>
#include "ollama-debug.h"
static int mul(int64_t *dims, int ndims) {
int result = 1;
for (int i = 0; i < ndims; i++) {
result *= dims[i];
}
return result;
}
static void repeat(char c, int n) {
for (int i = 0; i < n; i++) {
fprintf(stderr, "%c", c);
}
}
static void print_tensor(const void *tensor, void (*cb)(const void *, int),
int shape,
int64_t *dims, int ndims, int stride,
int nitems, int pad) {
fprintf(stderr, "[");
for (int i = 0; i < dims[0]; i++) {
if (i >= nitems && i < dims[0] - nitems) {
fprintf(stderr, "... (%" PRIi64 " more), ", dims[0] - 2 * nitems);
int skip = dims[0] - 2 * nitems;
if (ndims > 1) {
stride += mul(dims + 1, ndims - 1) * skip;
repeat('\n', ndims - 1);
repeat(' ', shape - ndims + 1 + pad);
}
i += skip - 1;
} else if (ndims > 1) {
print_tensor(tensor, cb, shape, dims + 1, ndims - 1, stride,
nitems, pad);
stride += mul(dims + 1, ndims - 1);
if (i < dims[0] - 1) {
fprintf(stderr, ", ");
repeat('\n', ndims - 1);
repeat(' ', shape - ndims + 1 + pad);
}
} else {
cb(tensor, stride + i);
if (i < dims[0] - 1) {
fprintf(stderr, ", ");
}
}
}
fprintf(stderr, "]");
}
static void print_tensor_f16(const void *tensor, int i) {
float value = ggml_fp16_to_fp32(((const ggml_fp16_t *)tensor)[i]);
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}
static void print_tensor_f32(const void *tensor, int i) {
float value = ((const float *)tensor)[i];
fprintf(stderr, "%s%f", value < 0 ? "" : " ", value);
}
static void print_tensor_i32(const void *tensor, int i) {
int32_t value = ((const int32_t *)tensor)[i];
fprintf(stderr, "%s%d", value < 0 ? "" : " ", value);
}
static void ollama_debug_tensor(const struct ggml_tensor *tensor, bool verbose, const char *prefix, int indent) {
fprintf(stderr, "%s%s %s (%s): [%" PRIi64 " %" PRIi64 " %" PRIi64 " %" PRIi64 "]\n", prefix, tensor->name,
ggml_op_name(tensor->op), ggml_type_name(tensor->type), tensor->ne[0],
tensor->ne[1], tensor->ne[2], tensor->ne[3]);
if (!verbose) {
return;
}
for (int i = 0; i < indent; i++) {
fprintf(stderr, " ");
}
switch (tensor->type) {
case GGML_TYPE_F16:
print_tensor(ggml_get_data(tensor), print_tensor_f16, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
case GGML_TYPE_F32:
print_tensor(ggml_get_data(tensor), print_tensor_f32, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
case GGML_TYPE_I32:
print_tensor(ggml_get_data(tensor), print_tensor_i32, ggml_n_dims(tensor),
(int64_t *)tensor->ne, ggml_n_dims(tensor), 0, 3, indent);
break;
default:
fprintf(stderr, "<unsupported type>\n");
return;
}
fprintf(stderr, "\n");
}
void ollama_debug(const struct ggml_tensor *tensor, bool verbose) {
ollama_debug_tensor(tensor, verbose, ">>> ", 4);
for (int i = 0; i < GGML_MAX_SRC && tensor->src[i] != NULL; ++i) {
char src[8];
const int n = snprintf(src, sizeof(src), " src%d ", i);
if (n >= sizeof(src)) {
src[sizeof(src) - 1] = '\0';
}
ollama_debug_tensor(tensor->src[i], verbose, src, 4);
}
}

View File

@@ -1,7 +0,0 @@
//go:build !debug
package ggml
func Threads(n int) int {
return n
}

View File

@@ -1,7 +0,0 @@
//go:build debug
package ggml
func Threads(_ int) int {
return 1
}

View File

@@ -3,7 +3,6 @@ package nn
 import (
 	"fmt"

-	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 )

@@ -12,50 +11,40 @@ import (
 //
 // Parameters:
 //   - ctx: Context for tensor operations
-//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
-//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
-//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
+//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
+//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
+//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
+//   - mask: Optional attention mask that is added to the attention score. If
+//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
 //   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
-//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
 //
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
-	if key != nil && value != nil {
-		if query.Dim(0) != key.Dim(0) {
-			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
-		}
-
-		if key.Dim(1) != value.Dim(1) {
-			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
-		}
-
-		if key.Dim(2) != value.Dim(2) {
-			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
-		}
-
-		if cache != nil {
-			cache.Put(ctx, key, value)
-		}
-	} else if cache == nil {
-		panic("key & value tensors must be provided if cache is nil")
+func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
+	if query.Dim(0) != key.Dim(0) {
+		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
 	}

-	var mask ml.Tensor
-	if cache != nil {
-		key, value, mask = cache.Get(ctx)
+	if mask != nil && query.Dim(1) != mask.Dim(1) {
+		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
 	}

-	// Only use the fast SDPA implementation if we have a cache, since that's what
-	// will do any expected backend-specific transformations for us
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
+	if key.Dim(1) != value.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
+	}
+
+	if mask != nil && key.Dim(1) != mask.Dim(0) {
+		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
+	}
+
+	if key.Dim(2) != value.Dim(2) {
+		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+	}
+
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
 		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
 	} else {
-		query = query.Permute(ctx, 0, 2, 1, 3)
-		key = key.Permute(ctx, 0, 2, 1, 3)
-		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 		kq := key.MulmatFullPrec(ctx, query)
 		kq = kq.Scale(ctx, scale)
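
Editor's note: under the mask-based signature, the cache choreography moves to the caller. The llama and mllama hunks later in this compare follow exactly the shape below; this is a minimal sketch against the kvcache.Cache and ml.Tensor interfaces used in those hunks, with illustrative names:

package example

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

// selfAttend shows the caller-side pattern the new signature expects: write
// k/v into the cache, read back the full history plus the mask, permute
// q/k/v into the layouts documented above, then call Attention with the mask.
func selfAttend(ctx ml.Context, q, k, v ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
	cache.Put(ctx, k, v)
	k, v, mask := cache.Get(ctx)

	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	return nn.Attention(ctx, q, k, v, mask, 1.0/math.Sqrt(float64(headDim)))
}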

View File

@@ -1,37 +0,0 @@
package input
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
// equality.
MultimodalHash uint64
}
// MultimodalIndex is a multimodal element (such as an image)
// together with an index into the slice of Inputs with the
// corresponding token. Note that the index is not the same
// as the position - to find that use the index with the
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal any
}
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
Multimodal []MultimodalIndex
Positions []int32
Sequences []int
Outputs []int32
}
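
Editor's note: to make the removed contract concrete, here is a small illustrative stream of the kind PostTokenize (documented in the model.go hunk below) would produce for a Llava-style model — the token values and embedding are invented for the example, and each placeholder carries a hash so the runner can cache it:

package example

import "github.com/ollama/ollama/model/input"

// promptWithImage sketches a post-tokenization input stream: the image
// embedding is attached to one placeholder input per image feature, so a
// batch boundary may fall inside the image without losing data.
func promptWithImage(embedding any, hash uint64, nFeatures int) []input.Input {
	inputs := []input.Input{{Token: 1}} // leading text token (illustrative)
	for i := 0; i < nFeatures; i++ {
		inputs = append(inputs, input.Input{Multimodal: embedding, MultimodalHash: hash})
	}
	return append(inputs, input.Input{Token: 2}) // trailing text token (illustrative)
}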

View File

@@ -3,6 +3,7 @@ package model
 import (
 	"errors"
 	"fmt"
+	"image"
 	_ "image/jpeg"
 	_ "image/png"
 	"log/slog"
@@ -15,52 +16,23 @@ import (
 	_ "golang.org/x/image/tiff"
 	_ "golang.org/x/image/webp"

-	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
-	"github.com/ollama/ollama/model/input"
 )

-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
-
-// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
-type Model interface {
-	Forward(ml.Context, input.Options) (ml.Tensor, error)
-
-	Backend() ml.Backend
-	Config() config
-}
-
-// MultimodalProcessor must be implemented by multimodal models.
-type MultimodalProcessor interface {
-	// EncodeMultimodal processes a single input (such as an image) and
-	// generates an output (typically an embedding) that can be used by the model.
-	//
-	// The return value is most typically an ml.Tensor, however, different
-	// types are possible, such as an object containing a tensor plus
-	// additional metadata, a slice of tensors or even just the original input.
-	//
-	// The result may be cached by the runner.
-	EncodeMultimodal(ml.Context, []byte) (any, error)
-
-	// PostTokenize is called after tokenization to allow the model to edit the
-	// input stream to correctly arrange multimodal elements.
-	//
-	// The input is a slice of tokens with the results of EncodeMultimodal interleaved
-	// in the order that the user provided them. Each element of the slice will be
-	// either a single token or single multimodal object.
-	//
-	// The model must ensure that inputs are stored according to how they will be
-	// processed and stored in the cache. For example, Llava-style models should insert
-	// placeholder tokens equal to the feature size of the corresponding image with
-	// the image itself attached to and split across these tokens. When Forward is called
-	// a partial subset of these tokens may be submitted according to the batch size.
-	//
-	// This function is also responsible for updating MultimodalHash for any Multimodal
-	// that is modified to ensure that there is a unique hash value that accurately
-	// represents the contents.
-	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
-}
+// Options contains the inputs for a model forward pass
+type Options struct {
+	Inputs    []int32
+	Positions []int32
+	Sequences []int
+	Outputs   []int32
+
+	Images []image.Image
+}
+
+type config struct {
+	Cache kvcache.Cache
+}

 // Base implements the common fields and methods for all models
@@ -69,10 +41,6 @@
 	config config
 }

-type config struct {
-	Cache kvcache.Cache
-}
-
 // Backend returns the underlying backend that will run the model
 func (m *Base) Backend() ml.Backend {
 	return m.b
@@ -82,6 +50,14 @@ func (m *Base) Config() config {
 	return m.config
 }

+// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
+type Model interface {
+	Forward(ml.Context, Options) (ml.Tensor, error)
+
+	Backend() ml.Backend
+	Config() config
+}
+
 var models = make(map[string]func(ml.Config) (Model, error))

 // Register registers a model constructor for the given architecture
@@ -124,36 +100,6 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
 	return m, nil
 }

-func NewTextProcessor(s string) (TextProcessor, error) {
-	r, err := os.Open(s)
-	if err != nil {
-		return nil, err
-	}
-	defer r.Close()
-
-	meta, _, err := fs.Decode(r, -1)
-	if err != nil {
-		return nil, err
-	}
-
-	return getTextProcessor(meta.KV())
-}
-
-func getTextProcessor(kv fs.KV) (TextProcessor, error) {
-	arch := kv.Architecture()
-	f, ok := models[arch]
-	if !ok {
-		return nil, fmt.Errorf("unsupported model architecture %q", arch)
-	}
-
-	m, err := f(kv)
-	if err != nil {
-		return nil, err
-	}
-
-	tp, ok := m.(TextProcessor)
-	if !ok {
-		return nil, fmt.Errorf("%v is not a TextProcessor", m)
-	}
-
-	return tp, nil
-}
-
 func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 	t := v.Type()
@@ -280,7 +226,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }

-func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 	if len(opts.Positions) != len(opts.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
 	}
@@ -291,7 +237,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
 		if err != nil {
 			return nil, err
 		}
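
Editor's note: a caller of the flattened Options struct looks like the sketch below. The one invariant Forward enforces (visible in the hunk above) is that Positions and Sequences have equal length; all field values here are illustrative:

package example

import (
	"image"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
)

// decodeStep sketches a single forward pass: Outputs names the positions
// whose logits are needed, and Images carries any attached image directly
// rather than through a separate multimodal input stream.
func decodeStep(ctx ml.Context, m model.Model, img image.Image) (ml.Tensor, error) {
	return model.Forward(ctx, m, model.Options{
		Inputs:    []int32{10, 11, 12},
		Positions: []int32{0, 1, 2},
		Sequences: []int{0, 0, 0},
		Outputs:   []int32{2}, // logits for the final position only
		Images:    []image.Image{img},
	})
}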

View File

@@ -3,15 +3,12 @@ package model

 import (
 	"reflect"
 	"slices"
-	"strings"
 	"testing"

 	"github.com/google/go-cmp/cmp"
-	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 	"github.com/ollama/ollama/ml/nn"
-	"github.com/ollama/ollama/model/input"
 )

 func TestParseTags(t *testing.T) {
@@ -137,40 +134,3 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
 		t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
 	}
 }
-
-func TestGetTextProcessor(t *testing.T) {
-	tp, err := getTextProcessor(fs.KV{})
-	if err == nil {
-		t.Error("expected error")
-	} else if !strings.Contains(err.Error(), "unsupported model architecture") {
-		t.Errorf("unexpected error: %v", err)
-	} else if tp != nil {
-		t.Error("expected nil tp")
-	}
-
-	models["dummy"] = func(ml.Config) (Model, error) {
-		return notTextProcessorModel{}, nil
-	}
-	tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
-	if err == nil {
-		t.Error("expected error")
-	} else if !strings.Contains(err.Error(), "not a TextProcessor") {
-		t.Errorf("unexpected error: %v", err)
-	} else if tp != nil {
-		t.Error("expected nil tp")
-	}
-}
-
-type notTextProcessorModel struct{}
-
-func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
-	panic("unimplemented")
-}
-
-func (notTextProcessorModel) Backend() ml.Backend {
-	panic("unimplemented")
-}
-
-func (notTextProcessorModel) Config() config {
-	panic("unimplemented")
-}

View File

@@ -1,220 +0,0 @@
package gemma2
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
attnLogitSoftcap float32
finalLogitSoftcap float32
largeModelScaling bool
}
type Model struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"` // just set to token_embd?
*Options
}
const (
gemma27BLayerCount = 46
)
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
m.Cache.SetConfig(ml.CacheConfig{})
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.Options.largeModelScaling = true
}
for i, layer := range m.Layers {
cacheType := i % 2
m.Cache.SetLayer(i)
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("gemma2", New)
}
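
Editor's note: the removed gemma2 forward pass applies its logit soft-capping as scale by 1/c, tanh, then scale by c — both for attention scores (attnLogitSoftcap) and final logits (finalLogitSoftcap). As a scalar function the sequence is simply the sketch below; the function name is ours, not from the diff:

package example

import "math"

// softcap squashes x smoothly into (-c, c): near zero it is approximately
// the identity, while large magnitudes saturate instead of growing without
// bound, keeping logits (and their gradients) well behaved.
func softcap(x, c float64) float64 {
	return c * math.Tanh(x/c)
}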

View File

@@ -1,177 +0,0 @@
package gemma3
import (
"bytes"
"encoding/binary"
"hash/fnv"
"image"
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.SentencePieceModel
*VisionModel `gguf:"v,vision"`
*TextModel
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
InputProjection *nn.Linear `gguf:"mm_input_projection"`
tokensPerImage int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
l := visionOutputs.Dim(0)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
patchesPerImage := imageSize / patchSize
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
return visionOutputs
}
func New(c ml.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
MultiModalProjector: &MultiModalProjector{
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
},
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return visionOutputs, nil
}
type imageToken struct {
embedding ml.Tensor
index int
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var result []input.Input
fnvHash := fnv.New64a()
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
imageInputs := []input.Input{
{Token: 108}, // "\n\n"
{Token: 255999}, // "<start_of_image>"
}
result = append(result, imageInputs...)
// add image embeddings
inputMultimodal := inp.Multimodal.(ml.Tensor)
for i := range inputMultimodal.Dim(1) {
fnvHash.Reset()
binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
fnvHash.Write([]byte{byte(i)})
imageToken := imageToken{embedding: inputMultimodal, index: i}
result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
}
result = append(result,
input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n"
)
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}
func init() {
model.Register("gemma3", New)
}

View File

@@ -1,247 +0,0 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
}
type TextModel struct {
model.Base
model.SentencePieceModel
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
const (
gemmaGlobalCacheCount = 6
gemma27BLayerCount = 62
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
func newTextModel(c ml.Config) *TextModel {
numBlocks := int(c.Uint("block_count"))
m := TextModel{
SentencePieceModel: model.NewSentencePieceModel(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Uints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
},
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
return &m
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
} else {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
}
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}
type TextMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positionIDs, cache, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
return hiddenState.Add(ctx, residual)
}
func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
var embedding ml.Tensor
var src, dst, length int
var except []int
for _, image := range multimodal {
imageToken := image.Multimodal.(imageToken)
imageSrc := imageToken.index
imageDst := image.Index
if embedding == nil {
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
src = imageSrc
dst = imageDst
length++
} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
length++
} else {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
embedding = imageToken.embedding
src = imageSrc
dst = imageDst
length = 1
}
except = append(except, imageDst)
}
if embedding != nil {
visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
}
return except
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
for i, layer := range m.Layers {
// gemma alternates between the sliding window (local) and causal (global)
// kv cache every 6 layers
cacheType := cacheTypeSWA
if (i+1)%gemmaGlobalCacheCount == 0 {
cacheType = cacheTypeCausal
}
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
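
Editor's note: the cache-selection logic in the loop above alternates five sliding-window layers with one global layer (gemmaGlobalCacheCount = 6). Stated as a standalone helper — a sketch of the schedule, not code from the diff:

package example

const (
	cacheTypeSWA = iota
	cacheTypeCausal
)

// cacheTypeForLayer reproduces the removed gemma3 schedule: every
// globalEvery-th layer (counting from 1) uses the causal (global) KV cache;
// all other layers use the sliding-window cache.
func cacheTypeForLayer(i, globalEvery int) int {
	if (i+1)%globalEvery == 0 {
		return cacheTypeCausal
	}
	return cacheTypeSWA
}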

View File

@@ -1,127 +0,0 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
return hiddenState
}
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
SelfAttention *VisionSelfAttention
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
eps float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := make([]int32, numPatches)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
}
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
},
}
}

View File

@@ -1,58 +0,0 @@
package gemma3
import (
"image"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize, patchSize, numChannels int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
numChannels: int(c.Uint("vision.num_channels")),
}
}
func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
var pixelVals, rVals, gVals, bVals []float32
bounds := img.Bounds()
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
}
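
Editor's note: the per-channel arithmetic in pack reduces to one formula per sample — scale the 8-bit channel into [0, 1], then standardize with that channel's ImageNet mean and standard deviation. A scalar sketch (the function name is ours):

package example

// normalize maps an 8-bit channel value into [0, 1] and standardizes it,
// matching the rVal/gVal/bVal computation in the pack helper above.
func normalize(v uint8, mean, std float32) float32 {
	return (float32(v)/255.0 - mean) / std
}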

View File

@@ -1,18 +1,16 @@
 package llama

 import (
-	"fmt"
 	"math"
-	"strings"

 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
-	"github.com/ollama/ollama/model/input"
 )

 type Options struct {
+	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
+
 	hiddenSize, numHeads, numKVHeads int
 	eps, ropeBase, ropeScale         float32
 	ropeDim                          uint32
@@ -31,10 +29,6 @@ type Model struct {
 }

 func New(c ml.Config) (model.Model, error) {
-	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
-		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
-	}
-
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -66,38 +60,43 @@ func New(c ml.Config) (model.Model, error) {
 }

 type SelfAttention struct {
 	Query  *nn.Linear `gguf:"attn_q"`
 	Key    *nn.Linear `gguf:"attn_k"`
 	Value  *nn.Linear `gguf:"attn_v"`
 	Output *nn.Linear `gguf:"attn_output"`
-	RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
 }

 func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	batchSize := hiddenState.Dim(1)
 	headDim := opts.hiddenSize / opts.numHeads
-	ropeType := uint32(0)

 	q := sa.Query.Forward(ctx, hiddenState)
 	q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
-	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
-	k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)

 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

+	cache.Put(ctx, k, v)
+	k, v, mask := cache.Get(ctx)
+
+	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

 	return sa.Output.Forward(ctx, kqv)
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
+	return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
 }

 type MLP struct {
@@ -139,18 +138,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }

-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}

-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}
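
Editor's note: with RopeFactors moved onto Options, one factors tensor serves every layer, so the shift hook no longer indexes into Layers. A minimal sketch of that shape, assuming the ml.Tensor RoPE method shown in the hunk above (the function name is ours):

package example

import "github.com/ollama/ollama/ml"

// shiftKeys re-rotates cached keys after a cache shift; the layer index is
// no longer needed because the rope factors are shared across layers.
func shiftKeys(ctx ml.Context, key, delta, factors ml.Tensor, ropeDim uint32, base, scale float32) (ml.Tensor, error) {
	return key.RoPE(ctx, delta, factors, ropeDim, base, scale), nil
}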

View File

@@ -1,18 +1,10 @@
 package mllama

 import (
-	"bytes"
-	"encoding/binary"
-	"fmt"
-	"hash/fnv"
-	"image"
-	"slices"
-
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
-	"github.com/ollama/ollama/model/input"
 )

 type Model struct {
@@ -33,10 +25,6 @@ const (
 )

 func New(c ml.Config) (model.Model, error) {
-	// Verify unified config
-	if c.Uint("vision.block_count") == 0 {
-		return nil, fmt.Errorf("non-unified vision model not supported")
-	}
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -55,103 +43,59 @@ func New(c ml.Config) (model.Model, error) {
 		TextModel:  newTextModel(c),
 	}

-	encoderCache := kvcache.NewEncoderCache()
-	encoderCache.SetConfig(ml.CacheConfig{})
-	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
+	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))

 	return &m, nil
 }

-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
-	if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
-		return nil, model.ErrNoVisionModel
-	}
-
-	image, _, err := image.Decode(bytes.NewReader(multimodalData))
-	if err != nil {
-		return nil, err
-	}
-
-	f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(image)
-	if err != nil {
-		return nil, err
-	}
-
-	pixelValues, err := ctx.Input().FromFloatSlice(f32s,
-		m.ImageProcessor.imageSize,
-		m.ImageProcessor.imageSize,
-		m.ImageProcessor.numChannels,
-		m.ImageProcessor.maxNumTiles,
-	)
-	if err != nil {
-		return nil, err
-	}
-
-	aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(aspectRatioID)}, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	positions := make([]int32, 1601)
-	for i := range positions {
-		positions[i] = int32(i)
-	}
-
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		return nil, err
-	}
-
-	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
-	return m.Projector.Forward(ctx, crossAttentionStates), nil
-}
-
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
-	var images []input.Input
-	fnvHash := fnv.New64a()
-
-	for i := range inputs {
-		if inputs[i].Multimodal == nil {
-			if len(images) > 0 {
-				inputs[i].Multimodal = images[0].Multimodal
-				inputs[i].MultimodalHash = images[0].MultimodalHash
-				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
-					fnvHash.Reset()
-					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
-					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
-					inputs[i].MultimodalHash = fnvHash.Sum64()
-				}
-				images = nil
-			}
-		} else {
-			images = append(images, inputs[i])
-			inputs[i].Token = -1
-		}
-	}
-
-	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
-
-	return inputs, nil
-}
-
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if len(opts.Multimodal) > 0 {
-		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+	if opts.Images != nil {
+		f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
+		if err != nil {
+			return nil, err
+		}
+
+		pixelValues, err := ctx.FromFloatSlice(f32s,
+			m.ImageProcessor.imageSize,
+			m.ImageProcessor.imageSize,
+			m.ImageProcessor.numChannels,
+			m.ImageProcessor.maxNumTiles,
+		)
+		if err != nil {
+			return nil, err
+		}
+
+		aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
+		if err != nil {
+			return nil, err
+		}
+
+		positions := make([]int32, 1601)
+		for i := range positions {
+			positions[i] = int32(i)
+		}
+
+		positionIDs, err := ctx.FromIntSlice(positions, len(positions))
+		if err != nil {
+			return nil, err
+		}
+
+		crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
+		crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
 	}

-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+	inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
 	}

-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions))
 	if err != nil {
 		return nil, err
 	}

-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

View File

@@ -10,31 +10,36 @@ import (
) )
type TextSelfAttention struct { type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"` Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"` Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"` Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
} }
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState) query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState) key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim)) scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention) return sa.Output.Forward(ctx, attention)
@@ -42,11 +47,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers // This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
return key, nil
} }
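Shift can repair cached keys with a single RoPE call because rotary embeddings compose additively in position: rotating to position p and then by a delta d lands exactly where rotating to p+d would. A self-contained check under the common RoPE formulation (this is the textbook form, not the ggml kernel):

```go
package main

import (
	"fmt"
	"math"
)

// rope rotates each dimension pair (2i, 2i+1) by pos * base^(-2i/d).
func rope(vec []float64, pos int, base float64) []float64 {
	out := make([]float64, len(vec))
	d := len(vec)
	for i := 0; i < d; i += 2 {
		theta := float64(pos) * math.Pow(base, -float64(i)/float64(d))
		sin, cos := math.Sincos(theta)
		out[i] = vec[i]*cos - vec[i+1]*sin
		out[i+1] = vec[i]*sin + vec[i+1]*cos
	}
	return out
}

func main() {
	k := []float64{1, 0, 0.5, -0.5}
	direct := rope(k, 3, 10000)                  // key computed at position 3
	shifted := rope(rope(k, 1, 10000), 2, 10000) // cached at 1, shifted by 2
	fmt.Println(direct)
	fmt.Println(shifted) // matches direct up to floating-point error
}
```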
type TextMLP struct { type TextMLP struct {
@@ -106,7 +107,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps) query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value ml.Tensor var key, value, mask ml.Tensor
if crossAttentionStates != nil { if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -118,23 +119,16 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
cache.Put(ctx, key, value) cache.Put(ctx, key, value)
} else {
key, value, mask = cache.Get(ctx)
} }
key, value, _ = cache.Get(ctx) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query) scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention) return ca.Output.Forward(ctx, attention)
@@ -197,6 +191,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
} }
type TextModelOptions struct { type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
ropeDim uint32 ropeDim uint32

View File

@@ -144,6 +144,8 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
return images return images
} }
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image // remove the "alpha" channel by drawing over a prefilled image
// //
//nolint:unused //nolint:unused

View File

@@ -1,8 +1,6 @@
package models package models
import ( import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama" _ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/mllama"
) )

View File

@@ -4,7 +4,6 @@ import (
"cmp" "cmp"
"iter" "iter"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"sync" "sync"
@@ -19,17 +18,8 @@ const (
SpecialEOS SpecialEOS
) )
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface { type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error) Encode(string) ([]int32, error)
Decode([]int32) (string, error) Decode([]int32) (string, error)
Is(int32, Special) bool Is(int32, Special) bool
} }
@@ -37,11 +27,11 @@ type TextProcessor interface {
type Vocabulary struct { type Vocabulary struct {
Values []string Values []string
Types []uint32 Types []uint32
Scores []float32 Scores []uint32
Merges []string Merges []string
BOS, EOS, EOT int32 BOS, EOS int32
AddBOS, AddEOS, AddEOT bool AddBOS, AddEOS bool
specialOnce sync.Once specialOnce sync.Once
special []string special []string
@@ -58,7 +48,7 @@ func (v *Vocabulary) Is(id int32, special Special) bool {
case SpecialBOS: case SpecialBOS:
return id == v.BOS return id == v.BOS
case SpecialEOS: case SpecialEOS:
return id == v.EOS || id == v.EOT return id == v.EOS
default: default:
return false return false
} }
@@ -86,9 +76,7 @@ func (v *Vocabulary) Decode(id int32) string {
func (v *Vocabulary) SpecialVocabulary() []string { func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() { v.specialOnce.Do(func() {
for i := range v.Values { for i := range v.Values {
if slices.Contains([]int{105, 106}, i) { if v.Types[i] == 3 {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i]) v.special = append(v.special, v.Values[i])
} }
} }
@@ -156,7 +144,7 @@ type merge struct {
runes []rune runes []rune
} }
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
fragments := []fragment{{value: s}} fragments := []fragment{{value: s}}
for _, special := range bpe.vocab.SpecialVocabulary() { for _, special := range bpe.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently // TODO: process special tokens concurrently
@@ -189,6 +177,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
for _, frag := range fragments { for _, frag := range fragments {
if len(frag.ids) > 0 { if len(frag.ids) > 0 {
ids = append(ids, frag.ids...) ids = append(ids, frag.ids...)
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
continue continue
} }
@@ -212,6 +201,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
// short circuit if the fragment is in the vocabulary // short circuit if the fragment is in the vocabulary
if id := bpe.vocab.Encode(sb.String()); id >= 0 { if id := bpe.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
continue continue
} }
@@ -285,13 +275,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
// TODO: handle the edge case where the rune isn't in the vocabulary // TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id) ids = append(ids, id)
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
} }
} }
} }
} }
} }
if addSpecial && len(ids) > 0 { if len(ids) > 0 {
if bpe.vocab.AddBOS { if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS { if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
@@ -338,5 +329,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
} }
} }
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil return sb.String(), nil
} }
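The addSpecial flag gates the BOS/EOS handling at the end of Encode. A compact sketch of that policy (the helper name is ours; note it still adds the token after warning, matching the code above):

```go
package main

import (
	"fmt"
	"log/slog"
)

// addSpecials prepends BOS and appends EOS when requested, warning if the
// prompt already carries them, as the Encode implementations above do.
func addSpecials(ids []int32, bos, eos int32, addBOS, addEOS bool) []int32 {
	if len(ids) == 0 {
		return ids
	}
	if addBOS {
		if ids[0] == bos {
			slog.Warn("adding bos token to prompt which already has it", "id", bos)
		}
		ids = append([]int32{bos}, ids...)
	}
	if addEOS {
		if ids[len(ids)-1] == eos {
			slog.Warn("adding eos token to prompt which already has it", "id", eos)
		}
		ids = append(ids, eos)
	}
	return ids
}

func main() {
	fmt.Println(addSpecials([]int32{5, 6}, 1, 2, true, true)) // [1 5 6 2]
}
```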

View File

@@ -1,249 +0,0 @@
package model
import (
"iter"
"strings"
"github.com/dlclark/regexp2"
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
"github.com/ollama/ollama/logging"
)
const spmWhitespaceSep = "▁"
var log = logging.NewLogger()
func replaceWhitespaceBySeperator(s string) string {
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
}
type SentencePieceModel struct {
maxTokenLen int
pre *regexp2.Regexp
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
for cnt := range vocab.Types {
switch vocab.Types[cnt] {
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
fallthrough
default:
counter[int(vocab.Types[cnt])] += 1
}
}
log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
maxTokenLen: maxTokenLen,
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
if !yield(m.String()) {
break
}
}
}
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := spm.vocab.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
log.Trace("fragments", "frags", fragments)
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
continue
}
for split := range spm.split(frag.value) {
split = replaceWhitespaceBySeperator(split)
var sb strings.Builder
sb.Write([]byte(split))
if id := spm.vocab.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
continue
}
runes := []rune(sb.String())
pq := queue.NewWith(func(a, b any) int {
priA := a.(*candidate)
priB := b.(*candidate)
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
return -1
}
return 1
})
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
log.Trace("tokenizer", "merges", merges)
pairwise := func(a, b int) *candidate {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
if id := spm.vocab.Encode(left + right); id >= 0 {
return &candidate{
a: a,
b: b,
score: spm.vocab.Scores[id],
}
}
return nil
}
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pq.Enqueue(pair)
}
}
pqv := pq.Values()
for _, v := range pqv {
e := v.(*candidate)
log.Trace("candidate", "candidate", e)
}
for !pq.Empty() {
v, _ := pq.Dequeue()
pair := v.(*candidate)
left, right := merges[pair.a], merges[pair.b]
log.Trace("pair", "left", left, "right", right)
if len(left.runes) == 0 || len(right.runes) == 0 {
continue
}
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pq.Enqueue(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pq.Enqueue(pair)
}
}
log.Trace("merges", "merges", merges)
for _, merge := range merges {
if len(merge.runes) > 0 {
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
} else {
log.Error("missing token", "token", string(merge.runes))
}
}
}
}
}
if addSpecial && len(ids) > 0 {
if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS {
log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
}
log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...)
}
if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS {
log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
}
log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS)
}
}
return ids, nil
}
type candidate struct {
a, b int
score float32
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
if _, err := sb.WriteString(data); err != nil {
return "", err
}
}
log.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}
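SentencePiece never stores raw spaces: Encode swaps them for the metasymbol and Decode swaps them back, which is what makes the roundtrips in the tests below exact.

```go
package main

import (
	"fmt"
	"strings"
)

const spmWhitespaceSep = "▁" // U+2581, the SentencePiece space metasymbol

func main() {
	s := " hello world "
	enc := strings.ReplaceAll(s, " ", spmWhitespaceSep)   // what the merge loop sees
	dec := strings.ReplaceAll(enc, spmWhitespaceSep, " ") // what Decode restores
	fmt.Printf("%q -> %q -> %q\n", s, enc, dec)
}
```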

View File

@@ -1,118 +0,0 @@
package model
import (
"log/slog"
"os"
"path/filepath"
"slices"
"testing"
"google.golang.org/protobuf/proto"
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
if err != nil {
t.Fatal(err)
}
var spm sentencepiece.ModelProto
if err := proto.Unmarshal(bts, &spm); err != nil {
t.Fatal(err)
}
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
var v Vocabulary
for _, piece := range spm.GetPieces() {
v.Values = append(v.Values, piece.GetPiece())
v.Scores = append(v.Scores, piece.GetScore())
switch t := piece.GetType(); t {
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
sentencepiece.ModelProto_SentencePiece_CONTROL,
sentencepiece.ModelProto_SentencePiece_UNUSED,
sentencepiece.ModelProto_SentencePiece_BYTE:
v.Types = append(v.Types, uint32(t))
default:
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// todo parse the special tokens file
// - this will roundtrip correctly but the <start_of_turn> and
// <end_of_turn> tokens aren't processed
v.Types = append(v.Types, tt)
}
}
return NewSentencePieceModel(preTokenizer, &v)
}
func TestSentencePieceEncode(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
tokenizer := loadSentencePieceVocab(t)
t.Run("basic roundtrip", func(t *testing.T) {
t.Parallel()
cases := []string{
"hello",
"hello ",
"hello ",
" hello",
" hello ",
" hello ",
"hello world",
"请考试我的软件12345",
"你好",
"Hello 你好 world!",
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
"Numbers and symbols: 123456789 +- */",
"Special tokens: <bos> text <eos>",
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
}
for _, want := range cases {
ids, err := tokenizer.Encode(want, true)
if err != nil {
t.Fatal(err)
}
if got, err := tokenizer.Decode(ids); err != nil {
t.Fatal(err)
} else if got != want {
t.Errorf("got %q, want %q [%#v]", got, want, ids)
}
}
})
t.Run("special tokens", func(t *testing.T) {
type candidate struct {
token string
ids []int32
}
cases := []candidate{
{"<bos>", []int32{2}},
{"<eos>", []int32{1}},
}
for _, want := range cases {
ids, err := tokenizer.Encode(want.token, true)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(ids, want.ids) {
t.Errorf("got %#v, want %#v", ids, want.ids)
}
}
})
}

View File

@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
t.Parallel() t.Parallel()
ids, err := tokenizer.Encode("hello world", true) ids, err := tokenizer.Encode("hello world")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
t.Errorf("got %q, want hello world", s) t.Errorf("got %q, want hello world", s)
} }
ids, err = tokenizer.Encode("hello <|end_of_text|>", true) ids, err = tokenizer.Encode("hello <|end_of_text|>")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
} }
for s, want := range cases { for s, want := range cases {
ids, err := tokenizer.Encode(s, true) ids, err := tokenizer.Encode(s)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
} }
for _, want := range cases { for _, want := range cases {
ids, err := tokenizer.Encode(want, true) ids, err := tokenizer.Encode(want)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
} }
for s, want := range cases { for s, want := range cases {
ids, err := tokenizer.Encode(s, true) ids, err := tokenizer.Encode(s)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) { b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for range b.N { for range b.N {
_, err := tokenizer.Encode(string(bts), true) _, err := tokenizer.Encode(string(bts))
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}) })
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) { b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
ids, err := tokenizer.Encode(string(bts), true) ids, err := tokenizer.Encode(string(bts))
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

Binary file not shown.

View File

@@ -116,9 +116,19 @@ func (i *Instance) Readline() (string, error) {
switch r { switch r {
case KeyUp: case KeyUp:
i.historyPrev(buf, &currentLineBuf) if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
case KeyDown: case KeyDown:
i.historyNext(buf, &currentLineBuf) if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(currentLineBuf)
}
}
case KeyLeft: case KeyLeft:
buf.MoveLeft() buf.MoveLeft()
case KeyRight: case KeyRight:
@@ -175,10 +185,6 @@ func (i *Instance) Readline() (string, error) {
esc = true esc = true
case CharInterrupt: case CharInterrupt:
return "", ErrInterrupt return "", ErrInterrupt
case CharPrev:
i.historyPrev(buf, &currentLineBuf)
case CharNext:
i.historyNext(buf, &currentLineBuf)
case CharLineStart: case CharLineStart:
buf.MoveToStart() buf.MoveToStart()
case CharLineEnd: case CharLineEnd:
@@ -240,24 +246,6 @@ func (i *Instance) HistoryDisable() {
i.History.Enabled = false i.History.Enabled = false
} }
func (i *Instance) historyPrev(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos > 0 {
if i.History.Pos == i.History.Size() {
*currentLineBuf = []rune(buf.String())
}
buf.Replace([]rune(i.History.Prev()))
}
}
func (i *Instance) historyNext(buf *Buffer, currentLineBuf *[]rune) {
if i.History.Pos < i.History.Size() {
buf.Replace([]rune(i.History.Next()))
if i.History.Pos == i.History.Size() {
buf.Replace(*currentLineBuf)
}
}
}
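The helpers above encode one invariant: the in-progress line is stashed exactly when leaving the live position (Pos == Size) and restored exactly when returning to it. A toy model of that behavior (types are ours):

```go
package main

import "fmt"

type hist struct {
	entries []string
	pos     int // == len(entries) means "on the live line"
	stash   string
}

func (h *hist) up(live *string) {
	if h.pos == 0 {
		return
	}
	if h.pos == len(h.entries) {
		h.stash = *live // leaving the live line: remember it
	}
	h.pos--
	*live = h.entries[h.pos]
}

func (h *hist) down(live *string) {
	if h.pos == len(h.entries) {
		return
	}
	h.pos++
	if h.pos == len(h.entries) {
		*live = h.stash // back past the newest entry: restore the live line
	} else {
		*live = h.entries[h.pos]
	}
}

func main() {
	h := &hist{entries: []string{"first", "second"}, pos: 2}
	line := "typing..."
	h.up(&line)
	h.up(&line)
	h.down(&line)
	h.down(&line)
	fmt.Println(line) // typing...
}
```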
func NewTerminal() (*Terminal, error) { func NewTerminal() (*Terminal, error) {
fd := os.Stdin.Fd() fd := os.Stdin.Fd()
termios, err := SetRawMode(fd) termios, err := SetRawMode(fd)

View File

@@ -931,6 +931,7 @@ func Execute(args []string) error {
slog.Info("starting go runner") slog.Info("starting go runner")
llama.BackendInit() llama.BackendInit()
slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)
server := &Server{ server := &Server{
batchSize: *batchSize, batchSize: *batchSize,

View File

@@ -5,12 +5,12 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"reflect"
"time" "time"
"github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
) )
type InputCache struct { type InputCache struct {
@@ -39,7 +39,10 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots) slots := make([]InputCacheSlot, numSlots)
for i := range slots { for i := range slots {
slots[i] = InputCacheSlot{Id: i} slots[i] = InputCacheSlot{
Id: i,
Inputs: make([]input, 0),
}
} }
cache := model.Config().Cache cache := model.Config().Cache
@@ -59,9 +62,9 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
func kvCacheTypeFromStr(s string) ml.DType { func kvCacheTypeFromStr(s string) ml.DType {
switch s { switch s {
case "q8_0": case "q8_0":
return ml.DTypeQ80 panic("kv cache quantization not yet implemented")
case "q4_0": case "q4_0":
return ml.DTypeQ40 panic("kv cache quantization not yet implemented")
default: default:
return ml.DTypeF16 return ml.DTypeF16
} }
@@ -80,7 +83,7 @@ type InputCacheSlot struct {
Id int Id int
// Inputs that are stored in the KV cache // Inputs that are stored in the KV cache
Inputs []input.Input Inputs []input
// is this cache actively being processed as part of a sequence? // is this cache actively being processed as part of a sequence?
InUse bool InUse bool
@@ -89,7 +92,7 @@ type InputCacheSlot struct {
lastUsed time.Time lastUsed time.Time
} }
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
var slot *InputCacheSlot var slot *InputCacheSlot
var numPast int32 var numPast int32
var err error var err error
@@ -140,7 +143,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
return slot, prompt, nil return slot, prompt, nil
} }
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
longest := int32(-1) longest := int32(-1)
var longestSlot *InputCacheSlot var longestSlot *InputCacheSlot
@@ -163,7 +166,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
return longestSlot, longest, nil return longestSlot, longest, nil
} }
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
oldest := time.Now() oldest := time.Now()
var oldestSlot *InputCacheSlot var oldestSlot *InputCacheSlot
@@ -199,7 +202,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
if longest > 0 && longestSlot != oldestSlot { if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs)) len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input.Input, longest) oldestSlot.Inputs = make([]input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil { if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -209,7 +212,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
return oldestSlot, longest, nil return oldestSlot, longest, nil
} }
func countCommonPrefix(a []input.Input, b []input.Input) int32 { func countCommonPrefix(a []input, b []input) int32 {
var count int32 var count int32
for i := range a { for i := range a {
@@ -217,7 +220,7 @@ func countCommonPrefix(a []input.Input, b []input.Input) int32 {
break break
} }
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash { if !reflect.DeepEqual(a[i], b[i]) {
break break
} }
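The prefix comparison decides how much of a cache slot is reusable. On the new side it compares token plus multimodal hash instead of reflect.DeepEqual, which is both cheaper and well-defined for tensor-backed inputs. A standalone sketch of that comparison:

```go
package main

import "fmt"

type input struct {
	Token          int32
	MultimodalHash uint64
}

// countCommonPrefix counts leading inputs whose token and multimodal hash
// both agree, so equal images compare cheaply by hash.
func countCommonPrefix(a, b []input) int32 {
	var n int32
	for i := range a {
		if i >= len(b) || a[i] != b[i] {
			break
		}
		n++
	}
	return n
}

func main() {
	a := []input{{Token: 1}, {Token: 2}, {MultimodalHash: 7}}
	b := []input{{Token: 1}, {Token: 2}, {MultimodalHash: 8}}
	fmt.Println(countCommonPrefix(a, b)) // 2
}
```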

View File

@@ -4,8 +4,6 @@ import (
"image" "image"
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/model/input"
) )
func TestCountCommon(t *testing.T) { func TestCountCommon(t *testing.T) {
@@ -15,50 +13,44 @@ func TestCountCommon(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
t1 []input.Input t1 []input
t2 []input.Input t2 []input
expected int32 expected int32
}{ }{
{ {
name: "Equal", name: "Equal",
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t1: []input{{token: 1}, {token: 2}, {token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []input{{token: 1}, {token: 2}, {token: 3}},
expected: 3, expected: 3,
}, },
{ {
name: "Prefix", name: "Prefix",
t1: []input.Input{{Token: 1}}, t1: []input{{token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []input{{token: 1}, {token: 2}, {token: 3}},
expected: 1, expected: 1,
}, },
{ {
name: "Image Prefix", name: "Image Prefix",
t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}}, t1: []input{{image: imgA}},
t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, t2: []input{{image: imgA}, {image: imgB}, {image: imgC}},
expected: 1, expected: 1,
}, },
{ {
name: "Mixed", name: "Mixed",
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, t1: []input{{token: 1}, {image: imgA}},
t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, t2: []input{{token: 1}, {image: imgA}, {token: 5}},
expected: 2, expected: 2,
}, },
{
name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{ {
name: "Empty", name: "Empty",
t1: []input.Input{}, t1: []input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, t2: []input{{token: 1}, {token: 2}, {token: 3}},
expected: 0, expected: 0,
}, },
{ {
name: "Both Empty", name: "Both Empty",
t1: []input.Input{}, t1: []input{},
t2: []input.Input{}, t2: []input{},
expected: 0, expected: 0,
}, },
} }
@@ -82,7 +74,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
cache InputCache cache InputCache
prompt []input.Input prompt []input
longest expected longest expected
best expected best expected
}{ }{
@@ -91,18 +83,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{}, Inputs: []input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input.Input{{Token: 1}}, prompt: []input{{token: 1}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0}, best: expected{result: 0, len: 0},
}, },
@@ -111,18 +103,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}}, Inputs: []input{{token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []input{{token: 1}, {token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []input{{token: 1}, {token: 2}},
longest: expected{result: 1, len: 2}, longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },
@@ -131,18 +123,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []input{{token: 1}, {token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}}, }},
prompt: []input.Input{{Token: 2}}, prompt: []input{{token: 2}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@@ -152,19 +144,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{ slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []input{{token: 1}, {token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{}, Inputs: []input{},
InUse: false, InUse: false,
lastUsed: time.Time{}, lastUsed: time.Time{},
}, },
}, },
}, },
prompt: []input.Input{{Token: 1}}, prompt: []input{{token: 1}},
longest: expected{result: 0, len: 1}, longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1}, best: expected{result: 1, len: 1},
}, },
@@ -173,18 +165,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}}, Inputs: []input{{token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []input{{token: 1}, {token: 2}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 2}, {Token: 3}}, prompt: []input{{token: 2}, {token: 3}},
longest: expected{result: 0, len: 0}, longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0}, best: expected{result: 1, len: 0},
}, },
@@ -193,18 +185,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{ cache: InputCache{slots: []InputCacheSlot{
{ {
Id: 0, Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}}, Inputs: []input{{token: 1}, {token: 2}},
InUse: true, InUse: true,
lastUsed: time.Now().Add(-time.Second), lastUsed: time.Now().Add(-time.Second),
}, },
{ {
Id: 1, Id: 1,
Inputs: []input.Input{{Token: 1}}, Inputs: []input{{token: 1}},
InUse: false, InUse: false,
lastUsed: time.Now().Add(-2 * time.Second), lastUsed: time.Now().Add(-2 * time.Second),
}, },
}}, }},
prompt: []input.Input{{Token: 1}, {Token: 2}}, prompt: []input{{token: 1}, {token: 2}},
longest: expected{result: 1, len: 1}, longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2}, best: expected{result: 1, len: 2},
}, },

View File

@@ -1,12 +1,13 @@
package ollamarunner package ollamarunner
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"hash/maphash" "image"
"log" "log"
"log/slog" "log/slog"
"net" "net"
@@ -26,26 +27,28 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample" "github.com/ollama/ollama/sample"
_ "github.com/ollama/ollama/model/models" _ "github.com/ollama/ollama/model/models"
) )
type Sequence struct { // input is an element of the prompt to process, either a token or an image
// ctx for allocating tensors that last the lifetime of the sequence, such as type input struct {
// multimodal embeddings token int32
ctx ml.Context
image image.Image
}
type Sequence struct {
// batch index // batch index
iBatch int iBatch int
// prompt inputs left to evaluate // prompt inputs left to evaluate
inputs []input.Input inputs []input
// inputs that have been added to a batch but not yet submitted to Forward // inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input.Input pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
@@ -98,9 +101,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
s.ready.Wait() s.ready.Wait()
startTime := time.Now() startTime := time.Now()
ctx := s.model.Backend().NewContext()
inputs, err := s.inputs(ctx, prompt, images) inputs, err := s.inputs(prompt, images)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err) return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 { } else if len(inputs) == 0 {
@@ -126,7 +128,6 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// TODO(jessegross): Ingest cached history for grammar // TODO(jessegross): Ingest cached history for grammar
return &Sequence{ return &Sequence{
ctx: ctx,
inputs: inputs, inputs: inputs,
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
@@ -145,31 +146,28 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs // inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and // by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images // decoding images
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input.Input var inputs []input
var parts []string var parts []string
var matches [][]string var matches [][]string
multimodalProcessor, visionModel := s.model.(model.MultimodalProcessor) // TODO(jessegross): This can sometimes trigger for matching text in the
// user's prompt. We previously tried to avoid it by only looking for images
// on image models. We don't have a clear indication now but it would be better
// to properly escape it in any case.
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
if visionModel {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
postTokenize := false
for i, part := range parts { for i, part := range parts {
// text - tokenize // text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) tokens, err := s.model.(model.TextProcessor).Encode(part)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, t := range tokens { for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t}) inputs = append(inputs, input{token: t})
} }
// image - decode and store // image - decode and store
@@ -188,25 +186,12 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in
return nil, fmt.Errorf("invalid image index: %d", n) return nil, fmt.Errorf("invalid image index: %d", n)
} }
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data))
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.multimodalHash.Reset() inputs = append(inputs, input{image: image})
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
if visionModel && postTokenize {
var err error
inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
if err != nil {
return nil, err
} }
} }
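The [img-<n>] handling above leans on how Split and FindAllStringSubmatch interleave: parts[i] is the text before the i-th tag and matches[i][1] is that tag's index. A runnable illustration (the prompt text is made up):

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`\[img-(\d+)\]`)
	prompt := "describe [img-0] versus [img-1] please"
	parts := re.Split(prompt, -1)
	matches := re.FindAllStringSubmatch(prompt, -1)
	for i, part := range parts {
		fmt.Printf("text %q\n", part)
		if i < len(matches) {
			fmt.Printf("image index %s\n", matches[i][1])
		}
	}
}
```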
@@ -251,15 +236,8 @@ type Server struct {
// KV cache // KV cache
cache *InputCache cache *InputCache
// multimodalHash generates hashes for comparing equality // next sequence for prompt processing to avoid starvation
// of non-text data nextSeq int
multimodalHash maphash.Hash
// vocab is a llama.cpp vocab required for grammar-based
// constrained generation (json mode, structured outputs)
// TODO: this is temporary until Ollama sampling supports
// constrained generation
vocab *sample.Vocab
} }
func (s *Server) allNil() bool { func (s *Server) allNil() bool {
@@ -305,7 +283,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
close(seq.responses) close(seq.responses)
close(seq.embedding) close(seq.embedding)
seq.cache.InUse = false seq.cache.InUse = false
seq.ctx.Close()
s.seqs[seqIndex] = nil s.seqs[seqIndex] = nil
s.seqsSem.Release(1) s.seqsSem.Release(1)
} }
@@ -333,25 +310,30 @@ func (s *Server) processBatch() error {
} }
defer s.mu.Unlock() defer s.mu.Unlock()
var options input.Options var options model.Options
imgSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
for i, seq := range s.seqs {
if seq == nil { if seq == nil {
continue continue
} }
// if past the num predict limit // if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, "limit") s.removeSequence(seqIdx, "limit")
continue continue
} }
if !s.cache.enabled { if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...) seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input.Input{} seq.cache.Inputs = []input{}
} }
for j, inp := range seq.inputs { for i, input := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 { if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -363,23 +345,37 @@ func (s *Server) processBatch() error {
} }
} }
if j >= s.batchSize { if i >= s.batchSize {
break break
} }
options.Inputs = append(options.Inputs, inp.Token) // TODO(jessegross): Image inputs need to be rethought - it's
if inp.Multimodal != nil { // it doesn't work well for different types of models or multiple sequences
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) if input.image != nil {
if len(seq.pendingInputs) != len(options.Images) {
break
}
if imgSeq != seqIdx && imgSeq != -1 {
s.nextSeq = seqIdx
break
}
imgSeq = seqIdx
options.Images = append(options.Images, input.image)
seq.pendingInputs = append(seq.pendingInputs, input)
continue
} }
options.Inputs = append(options.Inputs, input.token)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id) options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(options.Outputs) seq.iBatch = len(options.Outputs)
if j+1 == len(seq.inputs) { if i+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
} }
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, input)
} }
seq.inputs = seq.inputs[len(seq.pendingInputs):] seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -407,7 +403,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache // After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 { if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input.Input{} seq.pendingInputs = []input{}
} }
// don't sample prompt processing // don't sample prompt processing
@@ -426,7 +422,6 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return // if done processing the prompt, generate an embedding and return
if seq.embeddingOnly { if seq.embeddingOnly {
// TODO(jessegross): Embedding support // TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "") s.removeSequence(i, "")
continue continue
} }
@@ -454,7 +449,7 @@ func (s *Server) processBatch() error {
return err return err
} }
seq.inputs = []input.Input{{Token: token}} seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "") sequence := strings.Join(seq.pendingResponses, "")
@@ -580,30 +575,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
var grammar *sample.Grammar
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
}
}
sampler := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
grammar,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict, numPredict: req.NumPredict,
stop: req.Stop, stop: req.Stop,
numKeep: int32(req.NumKeep), numKeep: int32(req.NumKeep),
sampler: sampler, sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
embedding: false, embedding: false,
}) })
if err != nil { if err != nil {
@@ -691,6 +667,65 @@ type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"` Embedding []float32 `json:"embedding"`
} }
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
var req EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
slog.Debug("embedding request", "content", req.Content)
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
embedding := <-seq.embedding
if err := json.NewEncoder(w).Encode(&EmbeddingResponse{
Embedding: embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
type HealthResponse struct { type HealthResponse struct {
Status string `json:"status"` Status string `json:"status"`
Progress float32 `json:"progress"` Progress float32 `json:"progress"`
@@ -751,7 +786,7 @@ func (s *Server) loadModel(
panic(err) panic(err)
} }
s.vocab = sample.NewVocab(mpath) slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)
// TODO(jessegross): LoRA loading // TODO(jessegross): LoRA loading
if lpath.String() != "" { if lpath.String() != "" {
@@ -783,7 +818,7 @@ func Execute(args []string) error {
batchSize := fs.Int("batch-size", 512, "Batch size") batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU") mainGPU := fs.Int("main-gpu", 0, "Main GPU")
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention") _ = fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on") port := fs.Int("port", 8080, "Port to expose the server on")
@@ -828,6 +863,7 @@ func Execute(args []string) error {
} }
// TODO(jessegross): Parameters that need to be implemented: // TODO(jessegross): Parameters that need to be implemented:
// flash-attn
// no-mmap // no-mmap
// mlock // mlock
@@ -842,11 +878,10 @@ func Execute(args []string) error {
} }
params := ml.BackendParams{ params := ml.BackendParams{
NumThreads: *threads, NumThreads: *threads,
NumGPULayers: *numGPULayers, NumGPULayers: *numGPULayers,
MainGPU: *mainGPU, MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats, TensorSplit: tensorSplitFloats,
FlashAttention: *flashAttention,
} }
server.ready.Add(1) server.ready.Add(1)
@@ -868,13 +903,9 @@ func Execute(args []string) error {
defer listener.Close() defer listener.Close()
mux := http.NewServeMux() mux := http.NewServeMux()
// TODO: support embeddings mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/completion", server.completion)
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) mux.HandleFunc("/health", server.health)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)
httpServer := http.Server{ httpServer := http.Server{
Handler: mux, Handler: mux,

View File

@@ -3,223 +3,118 @@ package sample
import ( import (
"errors" "errors"
"math" "math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/llama" "golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
) )
// token represents information about a single token during sampling type Sampler interface {
type token struct { Sample([]float32) (int32, error)
id int32 // The token's unique identifier
value float32 // The raw logit or probability from the model
} }
type Sampler struct { type weighted struct {
rng *rand.Rand src rand.Source
topK int transforms []Transform
topP float32
minP float32
temperature float32
grammar *Grammar
} }
func (s *Sampler) Sample(logits []float32) (int32, error) { // TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279
tokens := make([]token, len(logits)) func Weighted(seed *uint64, transforms ...Transform) Sampler {
var src rand.Source
if seed != nil {
src = rand.NewSource(*seed)
}
return weighted{src: src, transforms: transforms}
}
func (s weighted) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
for _, t := range s.transforms {
logits64 = t.Apply(logits64)
}
logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits))
for i, logit := range logits64 {
if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit)
indices = append(indices, i)
}
}
if len(logitsCopy) == 0 {
return -1, errors.New("no valid logits found for weighed sampling")
}
probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighted sampler failed, no valid token found")
}
type greedy struct{}
func Greedy() Sampler {
return greedy{}
}
// Sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("no logits provided for greedy sampling")
}
maxIdx := 0
for i := range logits { for i := range logits {
tokens[i].id = int32(i) if logits[i] > logits[maxIdx] {
tokens[i].value = logits[i] maxIdx = i
}
t, err := s.sample(tokens)
if err != nil {
return -1, err
}
if s.grammar != nil {
// optimization: first check if the max logit is accepted by the grammar
// if the max logit is rejected, apply the grammar to all logits (slower)
top := []token{t}
s.grammar.Apply(top)
if !math.IsInf(float64(top[0].value), -1) {
s.grammar.Accept(top[0].id)
return top[0].id, nil
}
// since .sample has side effects of modifying the tokens
// we need to reset them before applying the grammar and
// sampling again
for i := range logits {
tokens[i].id = int32(i)
tokens[i].value = logits[i]
}
s.grammar.Apply(tokens)
t, err = s.sample(tokens)
if err != nil {
return -1, err
}
s.grammar.Accept(t.id)
}
return t.id, nil
}
// greedy returns the highest probability token from the tokens
func greedy(tokens []token) token {
max := tokens[0]
for i := 1; i < len(tokens); i++ {
if tokens[i].value > max.value {
max = tokens[i]
} }
} }
return max return int32(maxIdx), nil
}
// sample returns the highest probability token from the tokens
// given sampler parameters. It also has side effects of modifying the tokens
func (s *Sampler) sample(tokens []token) (token, error) {
if s.temperature == 0 {
return greedy(tokens), nil
}
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)
// TODO: this should fall back to greedy sampling
// or topP, topK values etc should be such that
// there are always tokens to sample from
if len(tokens) == 0 {
return token{}, errors.New("no tokens to sample from")
}
var r float32
if s.rng != nil {
r = s.rng.Float32()
} else {
r = rand.Float32()
}
// Calculate cumulative sum of probabilities
var sum float32
for i := range tokens {
sum += tokens[i].value
tokens[i].value = sum
}
r *= tokens[len(tokens)-1].value
idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
if token.value < target {
return -1
}
return 1
})
return tokens[idx], nil
} }
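The tail of sample is an inverse-CDF draw: accumulate probabilities in place, scale a uniform sample by the total, and binary-search for the first cumulative value at or above it. The same idea in isolation (sort.Search stands in for slices.BinarySearchFunc):

```go
package main

import (
	"fmt"
	"math/rand/v2"
	"sort"
)

// weightedPick draws an index with probability proportional to probs[i].
func weightedPick(probs []float32) int {
	cum := make([]float32, len(probs))
	var sum float32
	for i, p := range probs {
		sum += p
		cum[i] = sum
	}
	r := rand.Float32() * sum
	return sort.Search(len(cum), func(i int) bool { return cum[i] >= r })
}

func main() {
	counts := make([]int, 3)
	for range 10000 {
		counts[weightedPick([]float32{0.1, 0.3, 0.6})]++
	}
	fmt.Println(counts) // roughly [1000 3000 6000]
}
```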
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
var rng *rand.Rand if temperature == 0 {
if seed != -1 { return Greedy(), nil
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
if temperature < 0.0 {
temperature = 0.0
} }
if topP < 0.0 { if temperature < 0 || temperature > 2 {
topP = 0.0 return nil, errors.New("temperature must be between 0 and 2")
}
if topP >= 1.0 {
topP = 1.0
} }
if minP < 0.0 { transforms := []Transform{Temperature(temperature)}
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}
return Sampler{ if topK != 0 {
rng: rng, if topK <= 0 {
topK: topK, return nil, errors.New("topK must be greater than 0")
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
}
}
type Grammar struct {
vocab *Vocab
grammar string
sampler *llama.Sampler
}
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
v, err := vocab.Load()
if err != nil {
return nil, err
}
return &Grammar{
vocab: vocab,
grammar: grammar,
sampler: llama.NewGrammarSampler(v, grammar),
}, nil
}
func (g *Grammar) Apply(tokens []token) {
tds := make([]llama.TokenData, len(tokens))
for i, token := range tokens {
tds[i].Id = token.id
tds[i].Logit = token.value
}
g.sampler.Apply(tds)
for i := range tokens {
tokens[i].value = tds[i].Logit
}
}
func (g *Grammar) Accept(token int32) {
g.sampler.Accept(token)
}
type Vocab struct {
once sync.Once
vocab *llama.Vocab
err error
path string
}
func NewVocab(path string) *Vocab {
return &Vocab{path: path}
}
// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
v.once.Do(func() {
vocab, err := llama.LoadVocabFromFile(v.path)
if err != nil {
v.err = err
return
} }
v.vocab = vocab transforms = append(transforms, TopK(topK))
}) }
return v.vocab, v.err
if topP != 0 {
if topP < 0 || topP >= 1 {
return nil, errors.New("topP must be between 0 and 1")
}
transforms = append(transforms, TopP(topP))
}
if minP != 0 {
if minP < 0 || minP >= 1 {
return nil, errors.New("minP must be between 0 and 1")
}
transforms = append(transforms, MinP(minP))
}
if seed >= 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}
return Weighted(nil, transforms...), nil
} }
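(Aside: the Vocab type above is an instance of the memoized-load pattern: sync.Once guarantees the expensive parse runs at most once, and both the value and the error are cached for every later caller. The same pattern in miniature, as a hypothetical generic wrapper:)
import "sync"

// lazy memoizes the first call to load; later calls, from any goroutine,
// return the same (value, error) pair without re-running load.
type lazy[T any] struct {
	once sync.Once
	v    T
	err  error
}

func (l *lazy[T]) get(load func() (T, error)) (T, error) {
	l.once.Do(func() { l.v, l.err = load() })
	return l.v, l.err
}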

View File

@@ -1,92 +0,0 @@
package sample
import (
"fmt"
"math/rand"
"testing"
)
func BenchmarkWeightedSampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000}
for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
configs := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
}{
{"Greedy", 0, -1, 0, 0, -1},
{"Temperature", 0.8, -1, 0, 0, -1},
{"TopK", 0.8, 50, 0, 0, -1},
{"TopP", 0.8, -1, 0.9, 0, -1},
{"MinP", 0.8, -1, 0, 0.05, -1},
{"WithSeed", 0.8, 50, 0, 0, 42},
}
// Fixed size for common vocab size
size := 128000
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler.Sample(logits)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
func BenchmarkGreedySampler(b *testing.B) {
sizes := []int{10, 100, 1000, 10000, 100000}
for _, size := range sizes {
b.Run(fmt.Sprintf("Size %d", size), func(b *testing.B) {
logits := make([]float32, size)
for i := range logits {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
}
})
}
}

View File

@@ -1,14 +1,15 @@
package sample package sample
import ( import (
"math"
"math/rand/v2" "math/rand/v2"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
func TestWeighted(t *testing.T) { func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10} got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
sampler := NewSampler(0, 0, 0, 0, 0, nil)
got, err := sampler.Sample(logits)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@@ -18,26 +19,194 @@ func TestWeighted(t *testing.T) {
t.Errorf("index mismatch: want %d, got %d", want, got) t.Errorf("index mismatch: want %d, got %d", want, got)
} }
logits = []float32{-100, -10, 0, 10} got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
sampler = NewSampler(0, 0, 0, 0, 0, nil) if err == nil {
got, err = sampler.Sample(logits) t.Error("expected error for no valid tokens, got index", got)
}
seed := uint64(42)
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
want = int32(3) // Should pick highest probability with this r value // With seed 42, we expect a consistent sample
want = int32(3) // This will be deterministic due to the seed
if want != got { if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got) t.Errorf("index mismatch: want %d, got %d", want, got)
} }
} }
func BenchmarkSample(b *testing.B) { type testTransform struct {
samplers := map[string]Sampler{ id int
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy callOrder *[]int
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil), }
func (ts *testTransform) Apply(logits []float64) []float64 {
if ts.callOrder != nil {
*ts.callOrder = append(*ts.callOrder, ts.id)
}
return logits
}
func TestSample(t *testing.T) {
input := []float32{1, 2, 3, 4}
var callOrder []int
mock1 := &testTransform{
id: 1,
callOrder: &callOrder,
}
mock2 := &testTransform{
id: 2,
callOrder: &callOrder,
}
mock3 := &testTransform{
id: 3,
callOrder: &callOrder,
}
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}
func TestNewSampler(t *testing.T) {
tests := []struct {
name string
temperature float32
topK int
topP float32
minP float32
seed int
wantErr bool
}{
{
name: "no transforms",
// temperature is 0, so greedy should be used
wantErr: false,
},
{
name: "temperature",
temperature: 0.5,
wantErr: false,
},
{
name: "invalid temperature negative",
temperature: -1,
wantErr: true,
},
{
name: "invalid temperature too high",
temperature: 2.1,
wantErr: true,
},
{
name: "top k",
topK: 10,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
temperature: 0.8,
wantErr: true,
},
{
name: "top p",
topP: 0.9,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "min p",
minP: 0.2,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid min p negative",
minP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid min p one",
minP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "default values",
temperature: 0.8,
topK: 40,
topP: 0.9,
minP: 0.0,
seed: 0,
wantErr: false,
},
{
name: "all zeroes",
temperature: 0.0,
topK: 0,
topP: 0.0,
minP: 0.0,
seed: 0,
wantErr: false, // all zeroes means no transforms
},
{
name: "all transforms",
temperature: 0.8,
topK: 50,
topP: 0.95,
minP: 0.1,
seed: 42,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
if (err != nil) != tt.wantErr {
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func BenchmarkSample(b *testing.B) {
transforms := []Transform{
Temperature(0.5),
TopK(10),
TopP(0.9),
MinP(0.2),
}
samplers := map[string]Sampler{
"Greedy": Greedy(),
"Weighted": Weighted(nil, transforms...),
} }
// Generate random logits for benchmarking
logits := make([]float32, 1<<16) logits := make([]float32, 1<<16)
for i := range logits { for i := range logits {
logits[i] = rand.Float32() logits[i] = rand.Float32()
@@ -46,9 +215,9 @@ func BenchmarkSample(b *testing.B) {
for name, s := range samplers { for name, s := range samplers {
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for b.Loop() { for range b.N {
if _, err := s.Sample(logits); err != nil { if _, err := s.Sample(logits); err != nil {
b.Fatalf("error sampling: %v", err) b.Error(err)
} }
} }
}) })

View File

@@ -1,145 +1,120 @@
package sample package sample
import ( import (
"container/heap" "cmp"
"math" "math"
"slices" "slices"
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
) )
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements type Transform interface {
type tokenHeap []token Apply([]float64) []float64
func (h tokenHeap) Len() int { return len(h) }
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *tokenHeap) Push(x any) {
*h = append(*h, x.(token))
} }
func (h *tokenHeap) Pop() any { // TODO(parthsareen): potentially cache softmax values
old := *h func softmax(logits []float64) []float64 {
n := len(old) var sum float64
x := old[n-1] probs := make([]float64, len(logits))
*h = old[0 : n-1] for i, v := range logits {
return x probs[i] = math.Exp(v)
} sum += probs[i]
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token {
// Clamp temperature to a small positive value to avoid numerical instability near 0
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
} }
return ts
for i := range probs {
probs[i] /= sum
}
return probs
} }
// softmax applies normalization to the logits type Temperature float64
func softmax(ts []token) []token {
// Find max logit for numerical stability func (t Temperature) Apply(logits []float64) []float64 {
maxLogit := float32(math.Inf(-1)) temp := math.Max(float64(t), 1e-7)
for _, t := range ts {
if t.value > maxLogit { // subtracting max logit to avoid under/overflow
maxLogit = t.value maxLogit := slices.Max(logits)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
}
return logits
}
type logitMap struct {
index int
logit float64
}
type TopK int
// TODO(parthsareen): avoid having to check all logits after this transform
func (k TopK) Apply(logits []float64) []float64 {
if int(k) >= len(logits) {
return logits
}
q := pq.NewWith(func(a, b logitMap) int {
return -cmp.Compare(a.logit, b.logit)
})
for i, logit := range logits {
q.Enqueue(logitMap{index: i, logit: logit})
}
validLogits := make(map[int]float64)
for range k {
logitMap, _ := q.Dequeue()
validLogits[logitMap.index] = logitMap.logit
}
for i := range logits {
if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
} }
} }
// Compute exp(x - max) return logits
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
sum += ts[i].value
}
// exp(x - max) / sum(exp(x - max))
for i := range ts {
ts[i].value /= sum
}
return ts
} }
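(Aside: subtracting the max logit is what makes the exponentials safe: for inputs like {1000, 2000, 3000} a naive exp overflows to +Inf, while exp(x − max) keeps every argument ≤ 0 and every intermediate in (0, 1]. The same idea over float64, as a standalone sketch assuming imports of math and slices:)
func stableSoftmax(x []float64) []float64 {
	m := slices.Max(x)
	out := make([]float64, len(x))
	var sum float64
	for i, v := range x {
		out[i] = math.Exp(v - m) // argument is <= 0, so no overflow
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}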
// topK limits the number of tokens considered to the k highest logits type TopP float64
func topK(ts []token, k int) []token {
if k >= len(ts) || k <= 0 { func (p TopP) Apply(logits []float64) []float64 {
slices.SortFunc(ts, func(a, b token) int { probs := softmax(logits)
switch { indices := make([]int, len(probs))
case a.value < b.value: for i := range indices {
return 1 indices[i] = i
case a.value > b.value: }
return -1
default: // sort in descending order
return 0 slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(probs[j], probs[i])
})
var sum float64
for i, idx := range indices {
sum += probs[idx]
if sum > float64(p) {
for _, idx := range indices[i+1:] {
logits[idx] = math.Inf(-1)
} }
}) break
return ts
}
// Initialize min-heap with first k elements
h := make(tokenHeap, k)
copy(h, ts[:k])
heap.Init(&h)
// Process remaining elements
for i := k; i < len(ts); i++ {
if ts[i].value > h[0].value {
heap.Pop(&h)
heap.Push(&h, ts[i])
} }
} }
return logits
// Convert heap to sorted slice in descending order
result := make([]token, len(h))
for i := k - 1; i >= 0; i-- {
result[i] = heap.Pop(&h).(token)
}
return result
} }
// topP limits tokens to those with cumulative probability p type MinP float64
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
// Find cutoff index where cumulative sum exceeds p func (p MinP) Apply(logits []float64) []float64 {
var sum float32 probs := softmax(logits)
for i, t := range ts { threshold := slices.Max(probs) * float64(p)
sum += t.value
if sum > float32(p) { for i, prob := range probs {
ts = ts[:i+1] if prob < threshold {
return ts logits[i] = math.Inf(-1)
} }
} }
return ts return logits
}
// minP limits tokens to those whose probability is at least p times the maximum probability
func minP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
}
}
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts
} }
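(Aside: a worked example of the min-p cutoff with made-up numbers: for probabilities {0.5, 0.3, 0.1, 0.05} and p = 0.2, the threshold is 0.5 × 0.2 = 0.1, so the first three tokens survive and the last is dropped. Over plain probabilities, assuming the slices package:)
func minPKeep(probs []float64, p float64) []float64 {
	threshold := slices.Max(probs) * p
	kept := probs[:0] // filter in place, as the token version above does
	for _, pr := range probs {
		if pr >= threshold {
			kept = append(kept, pr)
		}
	}
	return kept
}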

View File

@@ -4,258 +4,77 @@ import (
"math" "math"
"math/rand/v2" "math/rand/v2"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
// Helper to convert float32 slice to logit slice
func toTokens(values []float32) []token {
tokens := make([]token, len(values))
for i, v := range values {
tokens[i] = token{
id: int32(i),
value: v,
}
}
return tokens
}
// Helper to compare logit slices
func compareLogits(t *testing.T, name string, want []float32, got []token) {
t.Helper()
if len(want) != len(got) {
t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
return
}
for i := range want {
if math.Abs(float64(got[i].value-want[i])) > 1e-6 {
t.Errorf("%s: index %d: want %f, got %f", name, i, want[i], got[i].value)
}
}
}
func TestTemperature(t *testing.T) { func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0} got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
got := temperature(toTokens(input), 0.5) want := []float64{-4, -10, 0, -14, -6, -12, -8}
want := []float32{2.0, 8.0, -4.0, 0.0} if diff := cmp.Diff(want, got); diff != "" {
compareLogits(t, "temperature(0.5)", want, got) t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
got = temperature(toTokens(input), 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, got)
got = temperature(toTokens(input), 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, got)
} }
func TestSoftmax(t *testing.T) { func TestSoftmax(t *testing.T) {
tests := []struct { got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
name string
input []float32
expected []float32
}{
{
name: "correctness softmax",
input: []float32{1, -2, 3, 0},
expected: []float32{0.113550, 0.005653, 0.839024, 0.041773},
},
{
name: "normal distribution",
input: []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367},
},
{
name: "single value",
input: []float32{1.0},
},
{
name: "identical values",
input: []float32{0.9, 0.9, 0.9},
},
{
name: "large values",
input: []float32{1000.0, 2000.0, 3000.0},
},
{
name: "small values",
input: []float32{1e-6, 2e-6, 3e-6},
},
{
name: "negative values",
input: []float32{-1.0, -2.0, -3.0},
},
{
name: "mixed values",
input: []float32{-100.0, 0.0, 100.0},
},
}
for _, tt := range tests { want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
t.Run(tt.name, func(t *testing.T) { if diff := cmp.Diff(want, got); diff != "" {
got := softmax(toTokens(tt.input)) t.Errorf("probs mismatch (-want +got):\n%s", diff)
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, got)
return
}
// Check probabilities sum to 1
var sum float32
for _, token := range got {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
}
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
})
} }
} }
func TestTopK(t *testing.T) { func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
// Test k=5 if diff := cmp.Diff(want, got); diff != "" {
got := topK(toTokens(input), 5) t.Errorf("logits mismatch (-want +got):\n%s", diff)
if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
}
// Should keep the highest 5 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(5)", want, got)
got = topK(toTokens(input), 20)
if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
} }
// Test k=-1 got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
// Test k=0 want = []float64{-3, -2, -1, 0, 1, 2, 4}
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} if diff := cmp.Diff(want, got); diff != "" {
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} t.Errorf("logits mismatch (-want +got):\n%s", diff)
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
} }
compareLogits(t, "topK(-1)", want, got)
} }
func TestTopP(t *testing.T) { func TestTopP(t *testing.T) {
input := []float32{-3, -2, -1, 0, 1, 2, 4} got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
tokens := toTokens(input) want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if diff := cmp.Diff(want, got); diff != "" {
// First apply temperature and softmax to get probabilities t.Errorf("logits mismatch (-want +got):\n%s", diff)
tokens = softmax(tokens)
tokens = topK(tokens, 20)
// Then apply topP
got := topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
} }
} }
func TestMinP(t *testing.T) { func TestMinP(t *testing.T) {
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3} got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
tokens := toTokens(input) want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
if diff := cmp.Diff(want, got); diff != "" {
// First apply temperature and softmax t.Errorf("logits mismatch (-want +got):\n%s", diff)
tokens = softmax(tokens)
// Then apply minP
got := minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
} }
} }
func TestSortLogits(t *testing.T) { func BenchmarkTransform(b *testing.B) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} transforms := map[string]Transform{
tokens := toTokens(input) "Temperature": Temperature(0.5),
"TopK": TopK(10),
tokens = topK(tokens, 20) "TopP": TopP(0.9),
"MinP": MinP(0.2),
for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value {
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
i, tokens[i].value, tokens[i-1].value)
}
} }
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} logits := make([]float64, 1<<16)
compareLogits(t, "sortLogits", want, tokens) for i := range logits {
} logits[i] = rand.Float64()
func BenchmarkTransforms(b *testing.B) {
// Generate random logits
tokens := make([]token, 1<<16)
for i := range tokens {
tokens[i] = token{
id: int32(i),
value: rand.Float32(),
}
} }
tokensCopy := make([]token, len(tokens)) for name, transform := range transforms {
b.Run(name, func(b *testing.B) {
b.Run("Temperature", func(b *testing.B) { b.ResetTimer()
b.ResetTimer() for range b.N {
for b.Loop() { transform.Apply(logits)
copy(tokensCopy, tokens) }
temperature(tokensCopy, 0.5) })
} }
})
b.Run("Softmax", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
softmax(tokensCopy)
}
})
b.Run("TopK", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topK(tokensCopy, 10)
}
})
b.Run("TopP", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topP(tokensCopy, 0.9)
}
})
b.Run("MinP", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
minP(tokensCopy, 0.2)
}
})
b.Run("SortTokens", func(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topK(tokensCopy, 200000)
}
})
} }

View File

@@ -80,14 +80,13 @@ function checkEnv() {
function buildOllama() { function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") { if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0 New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
& cmake --fresh --preset CPU --install-prefix $script:DIST_DIR & cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset CPU --config Release --parallel $script:JOBS & cmake --build --preset CPU --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip & cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -102,7 +101,7 @@ function buildOllama() {
# to avoid 2022 (or newer) from being used as the default
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR & cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS & cmake --build --preset "CUDA 11" --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip & cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -113,7 +112,7 @@ function buildOllama() {
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS & cmake --build --preset "CUDA 12" --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip & cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -132,7 +131,7 @@ function buildOllama() {
$env:HIPCXX="" $env:HIPCXX=""
$env:HIP_PLATFORM="" $env:HIP_PLATFORM=""
$env:CMAKE_PREFIX_PATH="" $env:CMAKE_PREFIX_PATH=""
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS & cmake --build --preset "ROCm" --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip & cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}

View File

@@ -77,12 +77,11 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then
fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama" $SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"
if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"

View File

@@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
@@ -35,7 +34,6 @@ var (
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format") errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
errUnknownType = errors.New("unknown type") errUnknownType = errors.New("unknown type")
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified") errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
errFilePath = errors.New("file path must be relative")
) )
func (s *Server) CreateHandler(c *gin.Context) { func (s *Server) CreateHandler(c *gin.Context) {
@@ -48,13 +46,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
return return
} }
for v := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
return
}
}
name := model.ParseName(cmp.Or(r.Model, r.Name)) name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() { if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
@@ -113,7 +104,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.Adapters != nil { if r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil { if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
if errors.Is(err, badReq) { if errors.Is(err, badReq) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return return
@@ -230,22 +221,8 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err return nil, err
} }
defer os.RemoveAll(tmpDir) defer os.RemoveAll(tmpDir)
// Set up a root to validate paths
root, err := os.OpenRoot(tmpDir)
if err != nil {
return nil, err
}
defer root.Close()
for fp, digest := range files { for fp, digest := range files {
if !fs.ValidPath(fp) {
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
}
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
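(Aside: the guard above layers two checks: fs.ValidPath rejects rooted paths and any ".." element, and os.OpenRoot (Go 1.24) scopes Stat to the directory so traversal and symlink escapes surface as errors. A sketch of the pattern, with hypothetical names:)
root, err := os.OpenRoot(tmpDir) // file ops through root cannot escape tmpDir
if err != nil {
	return err
}
defer root.Close()
for fp := range files {
	if !fs.ValidPath(fp) { // rejects "", absolute paths, and ".." elements
		return fmt.Errorf("file path must be relative: %s", fp)
	}
	if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
		return fmt.Errorf("path escapes root: %s: %w", fp, err)
	}
}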
blobPath, err := GetBlobsPath(digest) blobPath, err := GetBlobsPath(digest)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -293,7 +270,6 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer bin.Close()
f, _, err := ggml.Decode(bin, 0) f, _, err := ggml.Decode(bin, 0)
if err != nil { if err != nil {

View File

@@ -1,106 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestConvertFromSafetensors(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
return l.Digest
}
// Create a safetensors compatible file with empty JSON content
var buf bytes.Buffer
headerSize := int64(len("{}"))
binary.Write(&buf, binary.LittleEndian, headerSize)
buf.WriteString("{}")
model := makeTemp(buf.String())
config := makeTemp(`{
"architectures": ["LlamaForCausalLM"],
"vocab_size": 32000
}`)
tokenizer := makeTemp(`{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}`)
tests := []struct {
name string
filePath string
wantErr error
}{
// Invalid
{
name: "InvalidRelativePathShallow",
filePath: filepath.Join("..", "file.safetensors"),
wantErr: errFilePath,
},
{
name: "InvalidRelativePathDeep",
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
wantErr: errFilePath,
},
{
name: "InvalidNestedPath",
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
wantErr: errFilePath,
},
{
name: "AbsolutePathOutsideRoot",
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
wantErr: errFilePath, // Should fail since it's outside tmpDir
},
{
name: "ValidRelativePath",
filePath: "model.safetensors",
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create the minimum required file map for convertFromSafetensors
files := map[string]string{
tt.filePath: model,
"config.json": config,
"tokenizer.json": tokenizer,
}
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
if (tt.wantErr == nil && err != nil) ||
(tt.wantErr != nil && err == nil) ||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
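(Aside: the buffer setup in the test above constructs the smallest parseable safetensors file: an 8-byte little-endian header length followed by that many bytes of JSON. As a standalone helper, assuming imports of bytes and encoding/binary:)
// minimalSafetensors returns bytes that decode as a safetensors file with
// an empty JSON header: uint64 little-endian length, then the header itself.
func minimalSafetensors() []byte {
	var buf bytes.Buffer
	header := "{}"
	binary.Write(&buf, binary.LittleEndian, uint64(len(header)))
	buf.WriteString(header)
	return buf.Bytes()
}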

View File

@@ -27,7 +27,6 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -45,9 +44,9 @@ import (
// Errors // Errors
var ( var (
// ErrModelNotFound is returned when a manifest is not found in the // ErrManifestNotFound is returned when a manifest is not found in the
// cache or registry. // cache or registry.
ErrModelNotFound = errors.New("model not found") ErrManifestNotFound = errors.New("manifest not found")
// ErrManifestInvalid is returned when a manifest found in a local or // ErrManifestInvalid is returned when a manifest found in a local or
// remote cache is invalid. // remote cache is invalid.
@@ -74,22 +73,19 @@ const (
DefaultMaxChunkSize = 8 << 20 DefaultMaxChunkSize = 8 << 20
) )
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) { // DefaultCache returns a new disk cache for storing models. If the
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS") dir := os.Getenv("OLLAMA_MODELS")
if dir == "" { if dir == "" {
home, _ := os.UserHomeDir() home, err := os.UserHomeDir()
home = cmp.Or(home, ".") if err != nil {
return nil, err
}
dir = filepath.Join(home, ".ollama", "models") dir = filepath.Join(home, ".ollama", "models")
} }
return blob.Open(dir) return blob.Open(dir)
})
// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
} }
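(Aside: sync.OnceValues (Go 1.21) memoizes both return values of its function, so the cache is opened at most once per process even when callers race; the defaultCache variable above relies on exactly this. In miniature:)
var openOnce = sync.OnceValues(func() (string, error) {
	fmt.Println("opening") // printed exactly once, no matter how many calls
	return "handle", nil
})

func useTwice() {
	v1, _ := openOnce()
	v2, _ := openOnce() // returns the memoized value; the func does not rerun
	fmt.Println(v1, v2)
}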
// Error is the standard error returned by Ollama APIs. It can represent a
@@ -114,18 +110,7 @@ type Error struct {
} }
func (e *Error) Error() string { func (e *Error) Error() string {
var b strings.Builder return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.Status))
if e.Code != "" {
b.WriteString(": code ")
b.WriteString(e.Code)
}
if e.Message != "" {
b.WriteString(": ")
b.WriteString(e.Message)
}
return b.String()
} }
func (e *Error) LogValue() slog.Value { func (e *Error) LogValue() slog.Value {
@@ -183,10 +168,6 @@ func CompleteName(name string) string {
// Registry is a client for performing push and pull operations against an // Registry is a client for performing push and pull operations against an
// Ollama registry. // Ollama registry.
type Registry struct { type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
Cache *blob.DiskCache
// UserAgent is the User-Agent header to send with requests to the // UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient. // registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string UserAgent string
@@ -225,28 +206,18 @@ type Registry struct {
// It is only used when a layer is larger than [MaxChunkingThreshold]. // It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64 MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names // Mask, if set, is the name used to convert non-fully qualified
// to fully qualified names. If empty, [DefaultMask] is used. // names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
Mask string Mask string
} }
func (r *Registry) cache() (*blob.DiskCache, error) { func (r *Registry) completeName(name string) names.Name {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}
func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask mask := defaultMask
if r.Mask != "" { if r.Mask != "" {
mask = names.Parse(r.Mask) mask = names.Parse(r.Mask)
} }
n := names.Merge(names.Parse(name), mask) return names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
} }
// DefaultRegistry returns a new Registry configured from the environment. The // DefaultRegistry returns a new Registry configured from the environment. The
@@ -307,17 +278,12 @@ type PushParams struct {
} }
// Push pushes the model with the name in the cache to the remote registry. // Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error { func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
if p == nil { if p == nil {
p = &PushParams{} p = &PushParams{}
} }
c, err := r.cache() m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
if err != nil {
return err
}
m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil { if err != nil {
return err return err
} }
@@ -340,7 +306,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
t := traceFromContext(ctx) t := traceFromContext(ctx)
scheme, n, _, err := r.parseNameExtended(name) scheme, n, _, err := parseName(name, r.Mask)
if err != nil { if err != nil {
// This should never happen since ResolveLocal should have // This should never happen since ResolveLocal should have
// already validated the name. // already validated the name.
@@ -366,7 +332,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
n.Model(), n.Model(),
l.Digest, l.Digest,
) )
res, err := r.send(ctx, "POST", startURL, nil) res, err := r.doOK(ctx, "POST", startURL, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -390,7 +356,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
} }
req.ContentLength = l.Size req.ContentLength = l.Size
res, err = sendRequest(r.client(), req) res, err = doOK(r.client(), req)
if err == nil { if err == nil {
res.Body.Close() res.Body.Close()
} }
@@ -410,7 +376,7 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
n.Model(), n.Model(),
n.Tag(), n.Tag(),
) )
res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data)) res, err := r.doOK(ctx, "PUT", path, bytes.NewReader(m.Data))
if err == nil { if err == nil {
res.Body.Close() res.Body.Close()
} }
@@ -433,8 +399,8 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is // chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly // typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image". // utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, name string) error { func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := r.parseNameExtended(name) scheme, n, _, err := parseName(name, r.Mask)
if err != nil { if err != nil {
return err return err
} }
@@ -447,11 +413,6 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return fmt.Errorf("%w: no layers", ErrManifestInvalid) return fmt.Errorf("%w: no layers", ErrManifestInvalid)
} }
c, err := r.cache()
if err != nil {
return err
}
exists := func(l *Layer) bool { exists := func(l *Layer) bool {
info, err := c.Get(l.Digest) info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size return err == nil && info.Size == l.Size
@@ -459,15 +420,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
t := traceFromContext(ctx) t := traceFromContext(ctx)
g, ctx := errgroup.WithContext(ctx) var g errgroup.Group
g.SetLimit(r.maxStreams()) g.SetLimit(r.maxStreams())
layers := m.Layers for _, l := range m.Layers {
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
}
for _, l := range layers {
if exists(l) { if exists(l) {
t.update(l, l.Size, ErrCached) t.update(l, l.Size, ErrCached)
continue continue
@@ -484,9 +440,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
if l.Size <= r.maxChunkingThreshold() { if l.Size <= r.maxChunkingThreshold() {
g.Go(func() error { g.Go(func() error {
// TODO(bmizerany): retry/backoff like below in res, err := doOK(r.client(), req)
// the chunking case
res, err := sendRequest(r.client(), req)
if err != nil { if err != nil {
return err return err
} }
@@ -512,21 +466,19 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// fire an initial request to get the final URL and // fire an initial request to get the final URL and
// then use that URL for the chunk requests. // then use that URL for the chunk requests.
req.Header.Set("Range", "bytes=0-0") req.Header.Set("Range", "bytes=0-0")
res, err := sendRequest(r.client(), req) res, err := doOK(r.client(), req)
if err != nil { if err != nil {
return err return err
} }
res.Body.Close() res.Body.Close()
req = res.Request.WithContext(req.Context()) req = res.Request.WithContext(req.Context())
wp := writerPool{size: r.maxChunkSize()} streamNo := 0
tws := make([]*bufio.Writer, r.maxStreams()-1)
for chunk := range chunks.Of(l.Size, r.maxChunkSize()) { for chunk := range chunks.Of(l.Size, r.maxChunkSize()) {
if ctx.Err() != nil {
break
}
ticket := q.Take() ticket := q.Take()
bufIdx := streamNo % len(tws)
streamNo++
g.Go(func() (err error) { g.Go(func() (err error) {
defer func() { defer func() {
if err != nil { if err != nil {
@@ -540,18 +492,23 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
if err != nil { if err != nil {
return err return err
} }
err := func() error { err := func() error {
req := req.Clone(req.Context()) req := req.Clone(req.Context())
req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk)) req.Header.Set("Range", fmt.Sprintf("bytes=%s", chunk))
res, err := sendRequest(r.client(), req) res, err := doOK(r.client(), req)
if err != nil { if err != nil {
return err return err
} }
defer res.Body.Close() defer res.Body.Close()
tw := wp.get() tw := tws[bufIdx]
if tw == nil {
tw = bufio.NewWriterSize(nil, int(r.maxChunkSize()))
tws[bufIdx] = tw
}
tw.Reset(ticket) tw.Reset(ticket)
defer wp.put(tw) defer tw.Reset(nil) // release ticket
_, err = io.CopyN(tw, res.Body, chunk.Size()) _, err = io.CopyN(tw, res.Body, chunk.Size())
if err != nil { if err != nil {
@@ -593,14 +550,10 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified // Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model. // before attempting to unlink the model.
func (r *Registry) Unlink(name string) (ok bool, _ error) { func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
n, err := r.parseName(name) n := r.completeName(name)
if err != nil { if !n.IsFullyQualified() {
return false, err return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
c, err := r.cache()
if err != nil {
return false, err
} }
return c.Unlink(n.String()) return c.Unlink(n.String())
} }
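(Aside: the Pull path above probes once with Range: bytes=0-0 so redirects settle, then clones that final request per chunk with a real byte range. One ranged chunk request, sketched with illustrative names:)
req := finalReq.Clone(ctx)
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", off, off+size-1))
res, err := client.Do(req)
if err != nil {
	return err
}
defer res.Body.Close()
// a compliant server answers 206 Partial Content with exactly size bytes
_, err = io.CopyN(dst, res.Body, size)
return err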
@@ -610,9 +563,6 @@ type Manifest struct {
Name string `json:"-"` // the canonical name of the model Name string `json:"-"` // the canonical name of the model
Data []byte `json:"-"` // the raw data of the manifest Data []byte `json:"-"` // the raw data of the manifest
Layers []*Layer `json:"layers"` Layers []*Layer `json:"layers"`
// For legacy reasons, we still have to download the config layer.
Config *Layer `json:"config"`
} }
var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000") var emptyDigest, _ = blob.ParseDigest("sha256:0000000000000000000000000000000000000000000000000000000000000000")
@@ -676,18 +626,14 @@ type Layer struct {
Size int64 `json:"size"` Size int64 `json:"size"`
} }
// ResolveLocal resolves a name to a Manifest in the local cache. // ResolveLocal resolves a name to a Manifest in the local cache. The name is
func (r *Registry) ResolveLocal(name string) (*Manifest, error) { // parsed using [names.Split] but the scheme is ignored.
_, n, d, err := r.parseNameExtended(name) func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
if err != nil { _, n, d, err := parseName(name, r.Mask)
return nil, err
}
c, err := r.cache()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !d.IsValid() { if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String()) d, err = c.Resolve(n.String())
if err != nil { if err != nil {
return nil, err return nil, err
@@ -696,7 +642,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
data, err := os.ReadFile(c.GetFile(d)) data, err := os.ReadFile(c.GetFile(d))
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name) return nil, fmt.Errorf("%w: %s", ErrManifestNotFound, name)
} }
return nil, err return nil, err
} }
@@ -709,7 +655,7 @@ func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
// Resolve resolves a name to a Manifest in the remote registry. // Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) { func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := r.parseNameExtended(name) scheme, n, d, err := parseName(name, r.Mask)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -719,7 +665,7 @@ func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error)
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d) manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
} }
res, err := r.send(ctx, "GET", manifestURL, nil) res, err := r.doOK(ctx, "GET", manifestURL, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -744,7 +690,7 @@ func (r *Registry) client() *http.Client {
} }
// newRequest constructs a new request, ready to use, with the given method, // newRequest constructs a new request, ready to use, with the given method,
// url, and body, pre-signed with client [Key] and [UserAgent]. // url, and body, presigned with client Key and UserAgent.
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, url, body) req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil { if err != nil {
@@ -763,17 +709,11 @@ func (r *Registry) newRequest(ctx context.Context, method, url string, body io.R
return req, nil return req, nil
} }
// sendRequest makes a request with the given client and request, and returns the // doOK makes a request with the given client and request, and returns the
// response if the status code is 200. If the status code is not 200, an Error // response if the status code is 200. If the status code is not 200, an Error
// is parsed from the response body and returned. If any other error occurs, it // is parsed from the response body and returned. If any other error occurs, it
// is returned. // is returned.
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) { func doOK(c *http.Client, r *http.Request) (*http.Response, error) {
defer func() {
if err != nil {
err = fmt.Errorf("request error %s: %w", r.URL, err)
}
}()
if r.URL.Scheme == "https+insecure" { if r.URL.Scheme == "https+insecure" {
// TODO(bmizerany): clone client.Transport, set // TODO(bmizerany): clone client.Transport, set
// InsecureSkipVerify, etc. // InsecureSkipVerify, etc.
@@ -816,26 +756,20 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
// Use the raw body if we can't parse it as an error object. // Use the raw body if we can't parse it as an error object.
re.Message = string(out) re.Message = string(out)
} }
// coerce MANIFEST_UNKNOWN to ErrManifestNotFound
if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
return nil, ErrModelNotFound
}
re.Status = res.StatusCode re.Status = res.StatusCode
return nil, &re return nil, &re
} }
return res, nil return res, nil
} }
// send is a convenience method for making a request with newRequest and // doOK is a convenience method for making a request with newRequest and
// passing it to send with r.client(). // passing it to doOK with r.client().
func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { func (r *Registry) doOK(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
req, err := r.newRequest(ctx, method, path, body) req, err := r.newRequest(ctx, method, path, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return sendRequest(r.client(), req) return doOK(r.client(), req)
} }
// makeAuthToken creates an Ollama auth token for the given private key. // makeAuthToken creates an Ollama auth token for the given private key.
@@ -925,7 +859,7 @@ var supportedSchemes = []string{
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseNameExtended parses and validates an extended name, returning the scheme, name, // parseName parses and validates an extended name, returning the scheme, name,
// and digest. // and digest.
// //
// If the scheme is empty, scheme will be "https". If an unsupported scheme is // If the scheme is empty, scheme will be "https". If an unsupported scheme is
@@ -936,8 +870,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
// //
// If the name is not, once merged with the mask, fully qualified, // If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned. // [ErrNameInvalid] wrapped with a display friendly message is returned.
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) { func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s) scheme, name, digest := names.Split(s)
scheme = cmp.Or(scheme, "https") scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) { if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage) err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
@@ -960,58 +894,13 @@ func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ b
} }
} }
n, err := r.parseName(name) maskName := defaultMask
if err != nil { if mask != "" {
return "", names.Name{}, blob.Digest{}, err maskName = names.Parse(mask)
}
n := names.Merge(names.Parse(name), maskName)
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
} }
return scheme, n, d, nil return scheme, n, d, nil
} }
// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}
type writerPool struct {
size int64 // set by the caller
mu sync.Mutex
ws []*bufio.Writer
}
func (p *writerPool) get() *bufio.Writer {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.ws) == 0 {
return bufio.NewWriterSize(nil, int(p.size))
}
w := p.ws[len(p.ws)-1]
p.ws = p.ws[:len(p.ws)-1]
return w
}
func (p *writerPool) put(w *bufio.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
w.Reset(nil)
p.ws = append(p.ws, w)
}

View File

@@ -2,7 +2,6 @@ package ollama
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -73,7 +72,6 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// To simulate a network error, pass a handler that returns a 499 status code. // To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) { func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper() t.Helper()
c, err := blob.Open(t.TempDir()) c, err := blob.Open(t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -87,14 +85,13 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
} }
r := &Registry{ r := &Registry{
Cache: c,
HTTPClient: &http.Client{ HTTPClient: &http.Client{
Transport: recordRoundTripper(h), Transport: recordRoundTripper(h),
}, },
} }
link := func(name string, manifest string) { link := func(name string, manifest string) {
n, err := r.parseName(name) _, n, _, err := parseName(name, r.Mask)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -154,55 +151,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
} }
func TestPushZero(t *testing.T) { func TestPushZero(t *testing.T) {
rc, _ := newClient(t, okHandler) rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), "empty", nil) err := rc.Push(t.Context(), c, "empty", nil)
if !errors.Is(err, ErrManifestInvalid) { if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid) t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
} }
} }
func TestPushSingle(t *testing.T) { func TestPushSingle(t *testing.T) {
rc, _ := newClient(t, okHandler) rc, c := newClient(t, okHandler)
-	err := rc.Push(t.Context(), "single", nil)
+	err := rc.Push(t.Context(), c, "single", nil)
 	testutil.Check(t, err)
 }
 
 func TestPushMultiple(t *testing.T) {
-	rc, _ := newClient(t, okHandler)
-	err := rc.Push(t.Context(), "multiple", nil)
+	rc, c := newClient(t, okHandler)
+	err := rc.Push(t.Context(), c, "multiple", nil)
 	testutil.Check(t, err)
 }
 
 func TestPushNotFound(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		t.Errorf("unexpected request: %v", r)
 	})
-	err := rc.Push(t.Context(), "notfound", nil)
+	err := rc.Push(t.Context(), c, "notfound", nil)
 	if !errors.Is(err, fs.ErrNotExist) {
 		t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
 	}
 }
 
 func TestPushNullLayer(t *testing.T) {
-	rc, _ := newClient(t, nil)
-	err := rc.Push(t.Context(), "null", nil)
+	rc, c := newClient(t, nil)
+	err := rc.Push(t.Context(), c, "null", nil)
 	if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
 		t.Errorf("err = %v; want invalid manifest", err)
 	}
 }
 
 func TestPushSizeMismatch(t *testing.T) {
-	rc, _ := newClient(t, nil)
+	rc, c := newClient(t, nil)
 	ctx, _ := withTraceUnexpected(t.Context())
-	got := rc.Push(ctx, "sizemismatch", nil)
+	got := rc.Push(ctx, c, "sizemismatch", nil)
 	if got == nil || !strings.Contains(got.Error(), "size mismatch") {
 		t.Errorf("err = %v; want size mismatch", got)
 	}
 }
 
 func TestPushInvalid(t *testing.T) {
-	rc, _ := newClient(t, nil)
-	err := rc.Push(t.Context(), "invalid", nil)
+	rc, c := newClient(t, nil)
+	err := rc.Push(t.Context(), c, "invalid", nil)
 	if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
 		t.Errorf("err = %v; want invalid manifest", err)
 	}
@@ -210,7 +207,7 @@ func TestPushInvalid(t *testing.T) {
 func TestPushExistsAtRemote(t *testing.T) {
 	var pushed bool
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/uploads/") {
 			if !pushed {
 				// First push. Return an uploadURL.
@@ -238,35 +235,35 @@ func TestPushExistsAtRemote(t *testing.T) {
 	check := testutil.Checker(t)
 
-	err := rc.Push(ctx, "single", nil)
+	err := rc.Push(ctx, c, "single", nil)
 	check(err)
 
 	if !errors.Is(errors.Join(errs...), nil) {
 		t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
 	}
 
-	err = rc.Push(ctx, "single", nil)
+	err = rc.Push(ctx, c, "single", nil)
 	check(err)
 }
 
 func TestPushRemoteError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 			w.WriteHeader(500)
 			io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
 			return
 		}
 	})
-	got := rc.Push(t.Context(), "single", nil)
+	got := rc.Push(t.Context(), c, "single", nil)
 	checkErrCode(t, got, 500, "blob_error")
 }
 
 func TestPushLocationError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("Location", ":///x")
 		w.WriteHeader(http.StatusAccepted)
 	})
-	got := rc.Push(t.Context(), "single", nil)
+	got := rc.Push(t.Context(), c, "single", nil)
 	wantContains := "invalid upload URL"
 	if got == nil || !strings.Contains(got.Error(), wantContains) {
 		t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -274,14 +271,14 @@ func TestPushLocationError(t *testing.T) {
 	}
 }
 
 func TestPushUploadRoundtripError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if r.Host == "blob.store" {
 			w.WriteHeader(499) // force RoundTrip error on upload
 			return
 		}
 		w.Header().Set("Location", "http://blob.store/blobs/123")
 	})
-	got := rc.Push(t.Context(), "single", nil)
+	got := rc.Push(t.Context(), c, "single", nil)
 	if !errors.Is(got, errRoundTrip) {
 		t.Errorf("got = %v; want %v", got, errRoundTrip)
 	}
@@ -297,20 +294,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
 			os.Remove(c.GetFile(l.Digest))
 		},
 	})
-	got := rc.Push(ctx, "single", nil)
+	got := rc.Push(ctx, c, "single", nil)
 	if !errors.Is(got, fs.ErrNotExist) {
 		t.Errorf("got = %v; want fs.ErrNotExist", got)
 	}
 }
 
 func TestPushCommitRoundtripError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 			panic("unexpected")
 		}
 		w.WriteHeader(499) // force RoundTrip error
 	})
-	err := rc.Push(t.Context(), "zero", nil)
+	err := rc.Push(t.Context(), c, "zero", nil)
 	if !errors.Is(err, errRoundTrip) {
 		t.Errorf("err = %v; want %v", err, errRoundTrip)
 	}
@@ -324,8 +321,8 @@ func checkNotExist(t *testing.T, err error) {
 }
 
 func TestRegistryPullInvalidName(t *testing.T) {
-	rc, _ := newClient(t, nil)
-	err := rc.Pull(t.Context(), "://")
+	rc, c := newClient(t, nil)
+	err := rc.Pull(t.Context(), c, "://")
 	if !errors.Is(err, ErrNameInvalid) {
 		t.Errorf("err = %v; want %v", err, ErrNameInvalid)
 	}
@@ -340,10 +337,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
 	}
 	for _, resp := range cases {
-		rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+		rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 			io.WriteString(w, resp)
 		})
-		err := rc.Pull(t.Context(), "x")
+		err := rc.Pull(t.Context(), c, "x")
 		if !errors.Is(err, ErrManifestInvalid) {
 			t.Errorf("err = %v; want invalid manifest", err)
 		}
@@ -366,18 +363,18 @@ func TestRegistryPullNotCached(t *testing.T) {
 	})
 
 	// Confirm that the layer does not exist locally
-	_, err := rc.ResolveLocal("model")
+	_, err := rc.ResolveLocal(c, "model")
 	checkNotExist(t, err)
 	_, err = c.Get(d)
 	checkNotExist(t, err)
 
-	err = rc.Pull(t.Context(), "model")
+	err = rc.Pull(t.Context(), c, "model")
 	check(err)
 
 	mw, err := rc.Resolve(t.Context(), "model")
 	check(err)
-	mg, err := rc.ResolveLocal("model")
+	mg, err := rc.ResolveLocal(c, "model")
 	check(err)
 	if !reflect.DeepEqual(mw, mg) {
 		t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -402,7 +399,7 @@ func TestRegistryPullNotCached(t *testing.T) {
 func TestRegistryPullCached(t *testing.T) {
 	cached := blob.DigestFromBytes("exists")
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/blobs/") {
 			w.WriteHeader(499) // should not be called
 			return
@@ -425,7 +422,7 @@ func TestRegistryPullCached(t *testing.T) {
 	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
 	defer cancel()
 
-	err := rc.Pull(ctx, "single")
+	err := rc.Pull(ctx, c, "single")
 	testutil.Check(t, err)
 
 	want := []int64{6}
@@ -438,30 +435,30 @@ func TestRegistryPullCached(t *testing.T) {
 }
 
 func TestRegistryPullManifestNotFound(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusNotFound)
 	})
-	err := rc.Pull(t.Context(), "notfound")
+	err := rc.Pull(t.Context(), c, "notfound")
 	checkErrCode(t, err, 404, "")
 }
 
 func TestRegistryPullResolveRemoteError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusInternalServerError)
 		io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
 	})
-	err := rc.Pull(t.Context(), "single")
+	err := rc.Pull(t.Context(), c, "single")
 	checkErrCode(t, err, 500, "an_error")
 }
 
 func TestRegistryPullResolveRoundtripError(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		if strings.Contains(r.URL.Path, "/manifests/") {
 			w.WriteHeader(499) // force RoundTrip error
 			return
 		}
 	})
-	err := rc.Pull(t.Context(), "single")
+	err := rc.Pull(t.Context(), c, "single")
 	if !errors.Is(err, errRoundTrip) {
 		t.Errorf("err = %v; want %v", err, errRoundTrip)
 	}
@@ -514,7 +511,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
 	// Check that we pull all layers that we can.
 
-	err := rc.Pull(ctx, "mixed")
+	err := rc.Pull(ctx, c, "mixed")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -532,7 +529,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
 }
 
 func TestRegistryPullChunking(t *testing.T) {
-	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
+	rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
 		t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
 		if r.URL.Host != "blob.store" {
 			// The production registry redirects to the blob store.
@@ -570,7 +567,7 @@ func TestRegistryPullChunking(t *testing.T) {
 		},
 	})
 
-	err := rc.Pull(ctx, "remote")
+	err := rc.Pull(ctx, c, "remote")
 	testutil.Check(t, err)
 
 	want := []int64{0, 3, 6}
@@ -608,7 +605,7 @@ func TestInsecureSkipVerify(t *testing.T) {
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name) url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url) _, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") { if err == nil || !strings.Contains(err.Error(), "failed to verify") {
t.Errorf("err = %v; want cert verification failure", err) t.Errorf("err = %v; want cert verifiction failure", err)
} }
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name) url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
@@ -712,16 +709,25 @@ func TestErrorUnmarshal(t *testing.T) {
 //
 // It is only for testing error messages, not that all invalids and valids are
 // covered. Those are in other tests for names.Name and blob.Digest.
-func TestParseNameExtendedErrors(t *testing.T) {
+func TestParseNameErrors(t *testing.T) {
 	cases := []struct {
 		name string
 		err  error
 		want string
-	}{}
-	var r Registry
+	}{
+		{"x", nil, ""},
+		{"x@", nil, ""},
+		{"", ErrNameInvalid, `invalid or missing name: ""`},
+		{"://", ErrNameInvalid, `invalid or missing name: "://"`},
+		{"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
+		{"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
+		{"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
+	}
 	for _, tt := range cases {
-		_, _, _, err := r.parseNameExtended(tt.name)
+		_, _, _, err := parseName(tt.name, DefaultMask)
 		if !errors.Is(err, tt.err) {
 			t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
 		}
@@ -730,89 +736,3 @@ func TestParseNameExtendedErrors(t *testing.T) {
 		}
 	}
 }
-
-func TestParseNameExtended(t *testing.T) {
-	cases := []struct {
-		in     string
-		scheme string
-		name   string
-		digest string
-		err    string
-	}{
-		{in: "http://m", scheme: "http", name: "m"},
-		{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
-		{in: "http+insecure://m", err: "unsupported scheme"},
-		{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
-		{in: "", err: "invalid or missing name"},
-		{in: "m", scheme: "https", name: "m"},
-		{in: "://", err: "invalid or missing name"},
-		{in: "@sha256:deadbeef", err: "invalid digest"},
-		{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
-	}
-	for _, tt := range cases {
-		t.Run(tt.in, func(t *testing.T) {
-			var r Registry
-			scheme, n, digest, err := r.parseNameExtended(tt.in)
-			if err != nil {
-				if tt.err == "" {
-					t.Errorf("err = %v; want nil", err)
-				} else if !strings.Contains(err.Error(), tt.err) {
-					t.Errorf("err = %v; want %q", err, tt.err)
-				}
-			} else if tt.err != "" {
-				t.Errorf("err = nil; want %q", tt.err)
-			}
-			if err == nil && !n.IsFullyQualified() {
-				t.Errorf("name = %q; want fully qualified", n)
-			}
-			if scheme != tt.scheme {
-				t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
-			}
-			// smoke-test name is superset of tt.name
-			if !strings.Contains(n.String(), tt.name) {
-				t.Errorf("name = %q; want %q", n, tt.name)
-			}
-			tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
-			if digest.String() != tt.digest {
-				t.Errorf("digest = %q; want %q", digest, tt.digest)
-			}
-		})
-	}
-}
-
-func TestUnlink(t *testing.T) {
-	t.Run("found by name", func(t *testing.T) {
-		rc, _ := newClient(t, nil)
-
-		// confirm linked
-		_, err := rc.ResolveLocal("single")
-		if err != nil {
-			t.Errorf("unexpected error: %v", err)
-		}
-
-		// unlink
-		_, err = rc.Unlink("single")
-		testutil.Check(t, err)
-
-		// confirm unlinked
-		_, err = rc.ResolveLocal("single")
-		if !errors.Is(err, fs.ErrNotExist) {
-			t.Errorf("err = %v; want fs.ErrNotExist", err)
-		}
-	})
-	t.Run("not found by name", func(t *testing.T) {
-		rc, _ := newClient(t, nil)
-		ok, err := rc.Unlink("manifestNotFound")
-		if err != nil {
-			t.Fatal(err)
-		}
-		if ok {
-			t.Error("expected not found")
-		}
-	})
-}
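Taken together, the test changes above show the shape of the reworked client API: the disk cache is no longer implicit in the Registry but passed explicitly at each call site. A minimal call-site sketch (DefaultCache, DefaultRegistry, and the Pull signature all appear in this compare; the model name is illustrative):

    package main

    import (
    	"context"
    	"log"

    	"github.com/ollama/ollama/server/internal/client/ollama"
    )

    func main() {
    	c, err := ollama.DefaultCache()
    	if err != nil {
    		log.Fatal(err)
    	}
    	rc, err := ollama.DefaultRegistry()
    	if err != nil {
    		log.Fatal(err)
    	}
    	// Pull threads the cache through explicitly rather than
    	// reading it from a field on the Registry.
    	if err := rc.Pull(context.Background(), c, "library/smol"); err != nil {
    		log.Fatal(err)
    	}
    }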

View File

@@ -6,20 +6,13 @@ import (
 // Trace is a set of functions that are called to report progress during blob
 // downloads and uploads.
-//
-// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
-// and [Registry.Pull].
 type Trace struct {
 	// Update is called during [Registry.Push] and [Registry.Pull] to
 	// report the progress of blob uploads and downloads.
 	//
-	// The n argument is the number of bytes transferred so far, and err is
-	// any error that has occurred. If n == 0, and err is nil, the download
-	// or upload has just started. If err is [ErrCached], the download or
-	// upload has been skipped because the blob is already present in the
-	// local cache or remote registry, respectively. Otherwise, if err is
-	// non-nil, the download or upload has failed. When l.Size == n, and
-	// err is nil, the download or upload has completed.
+	// It is called once at the beginning of the download with a zero n and
+	// then once per read operation with the number of bytes read so far,
+	// and an error if any.
 	//
 	// A function assigned must be safe for concurrent use. The function is
 	// called synchronously and so should not block or take long to run.
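For context, a sketch of how a caller hooks into these callbacks (WithTrace, Layer, ErrCached, and the Update signature all appear elsewhere in this compare; the rc/c wiring is as in the earlier sketch, and the logging is illustrative):

    ctx := ollama.WithTrace(context.Background(), &ollama.Trace{
    	// Must be safe for concurrent use; called synchronously, so keep it fast.
    	Update: func(l *ollama.Layer, n int64, err error) {
    		switch {
    		case errors.Is(err, ollama.ErrCached):
    			log.Printf("%s: already present; skipped", l.Digest)
    		case err != nil:
    			log.Printf("%s: failed: %v", l.Digest, err)
    		default:
    			log.Printf("%s: %d/%d bytes", l.Digest, n, l.Size)
    		}
    	},
    })
    err := rc.Pull(ctx, c, "library/smol")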

View File

@@ -63,28 +63,25 @@ func main() {
 	}
} }
flag.Parse() flag.Parse()
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
ctx := context.Background() ctx := context.Background()
err := func() error { err = func() error {
switch cmd := flag.Arg(0); cmd { switch cmd := flag.Arg(0); cmd {
case "pull": case "pull":
rc, err := ollama.DefaultRegistry() return cmdPull(ctx, rc, c)
if err != nil {
log.Fatal(err)
}
return cmdPull(ctx, rc)
case "push": case "push":
rc, err := ollama.DefaultRegistry() return cmdPush(ctx, rc, c)
if err != nil {
log.Fatal(err)
}
return cmdPush(ctx, rc)
case "import": case "import":
c, err := ollama.DefaultCache()
if err != nil {
log.Fatal(err)
}
return cmdImport(ctx, c) return cmdImport(ctx, c)
default: default:
if cmd == "" { if cmd == "" {
@@ -102,7 +99,7 @@ func main() {
} }
} }
func cmdPull(ctx context.Context, rc *ollama.Registry) error { func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
model := flag.Arg(1) model := flag.Arg(1)
if model == "" { if model == "" {
flag.Usage() flag.Usage()
@@ -148,7 +145,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry) error {
errc := make(chan error) errc := make(chan error)
go func() { go func() {
errc <- rc.Pull(ctx, model) errc <- rc.Pull(ctx, c, model)
}() }()
t := time.NewTicker(time.Second) t := time.NewTicker(time.Second)
@@ -164,7 +161,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry) error {
} }
} }
func cmdPush(ctx context.Context, rc *ollama.Registry) error { func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
args := flag.Args()[1:] args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError) flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.") flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@@ -180,7 +177,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry) error {
} }
from := cmp.Or(*flagFrom, model) from := cmp.Or(*flagFrom, model)
m, err := rc.ResolveLocal(from) m, err := rc.ResolveLocal(c, from)
if err != nil { if err != nil {
return err return err
} }
@@ -206,7 +203,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry) error {
}, },
}) })
return rc.Push(ctx, model, &ollama.PushParams{ return rc.Push(ctx, c, model, &ollama.PushParams{
From: from, From: from,
}) })
} }
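The -from flag above maps to PushParams.From, which, per its flag description, pushes a model under one name using the manifest resolved locally for another. In isolation (a sketch; the names are illustrative and the rc/c wiring is as above):

    // Push "mine/smol" using the manifest already linked locally
    // as "library/smol".
    err := rc.Push(ctx, c, "mine/smol", &ollama.PushParams{
    	From: "library/smol",
    })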

View File

@@ -1,5 +1,3 @@
-//go:build goexperiment.synctest
-
 package backoff
 
 import (

View File

@@ -1,5 +1,3 @@
-//go:build goexperiment.synctest
-
 package syncs
 
 import (
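The removed build tags gate these test files on Go's synctest experiment: with the tag in place, the files only compile when the experiment is enabled at build time, e.g. (assuming a Go toolchain that ships the experiment):

    GOEXPERIMENT=synctest go test ./...

Dropping the tag means the tests always build, which only works on the side of the compare that does not rely on the testing/synctest package.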

View File

@@ -7,12 +7,9 @@ import (
 	"cmp"
 	"encoding/json"
 	"errors"
-	"fmt"
 	"io"
 	"log/slog"
 	"net/http"
-	"sync"
-	"time"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
@@ -30,15 +27,12 @@ import (
 // directly to the blob disk cache.
 type Local struct {
 	Client *ollama.Registry // required
+	Cache  *blob.DiskCache  // required
 	Logger *slog.Logger     // required
 
 	// Fallback, if set, is used to handle requests that are not handled by
 	// this handler.
 	Fallback http.Handler
-
-	// Prune, if set, is called to prune the local disk cache after a model
-	// is deleted.
-	Prune func() error // optional
 }
 
 // serverError is like ollama.Error, but with a Status field for the HTTP
@@ -113,8 +107,6 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
 	switch r.URL.Path {
 	case "/api/delete":
 		return false, s.handleDelete(rec, r)
-	case "/api/pull":
-		return false, s.handlePull(rec, r)
 	default:
 		if s.Fallback != nil {
 			s.Fallback.ServeHTTP(rec, r)
@@ -207,107 +199,13 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
 	if err != nil {
 		return err
 	}
-	ok, err := s.Client.Unlink(p.model())
+	ok, err := s.Client.Unlink(s.Cache, p.model())
 	if err != nil {
 		return err
 	}
 	if !ok {
 		return &serverError{404, "not_found", "model not found"}
 	}
-	if s.Prune == nil {
-		return nil
-	}
-	return s.Prune()
-}
-
-type progressUpdateJSON struct {
-	Status    string      `json:"status"`
-	Digest    blob.Digest `json:"digest,omitempty,omitzero"`
-	Total     int64       `json:"total,omitempty,omitzero"`
-	Completed int64       `json:"completed,omitempty,omitzero"`
-}
-
-func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
-	if r.Method != "POST" {
-		return errMethodNotAllowed
-	}
-
-	p, err := decodeUserJSON[*params](r.Body)
-	if err != nil {
-		return err
-	}
-
-	maybeFlush := func() {
-		fl, _ := w.(http.Flusher)
-		if fl != nil {
-			fl.Flush()
-		}
-	}
-	defer maybeFlush()
-
-	var mu sync.Mutex
-	enc := json.NewEncoder(w)
-	enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
-
-	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
-		Update: func(l *ollama.Layer, n int64, err error) {
-			mu.Lock()
-			defer mu.Unlock()
-
-			// TODO(bmizerany): coalesce these updates; writing per
-			// update is expensive
-			enc.Encode(progressUpdateJSON{
-				Digest:    l.Digest,
-				Status:    "pulling",
-				Total:     l.Size,
-				Completed: n,
-			})
-		},
-	})
-
-	done := make(chan error, 1)
-	go func() {
-		// TODO(bmizerany): continue to support non-streaming responses
-		done <- s.Client.Pull(ctx, p.model())
-	}()
-
-	func() {
-		t := time.NewTicker(100 * time.Millisecond)
-		defer t.Stop()
-		for {
-			select {
-			case <-t.C:
-				mu.Lock()
-				maybeFlush()
-				mu.Unlock()
-			case err := <-done:
-				if err != nil {
-					var status string
-					if errors.Is(err, ollama.ErrModelNotFound) {
-						status = fmt.Sprintf("error: model %q not found", p.model())
-						enc.Encode(progressUpdateJSON{Status: status})
-					} else {
-						status = fmt.Sprintf("error: %v", err)
-						enc.Encode(progressUpdateJSON{Status: status})
-					}
-					return
-				}
-
-				// These final updates are not strictly necessary, because they have
-				// already happened at this point. Our pull handler code used to do
-				// these steps after, not during, the pull, and they were slow, so we
-				// wanted to provide feedback to users what was happening. For now, we
-				// keep them to not jar users who are used to seeing them. We can phase
-				// them out with a new and nicer UX later. One without progress bars
-				// and digests that no one cares about.
-				enc.Encode(progressUpdateJSON{Status: "verifying layers"})
-				enc.Encode(progressUpdateJSON{Status: "writing manifest"})
-				enc.Encode(progressUpdateJSON{Status: "success"})
-				return
-			}
-		}
-	}()
 	return nil
 }
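The removed handler streams one progressUpdateJSON object per line. A client-side sketch of consuming that stream (field names come from the struct above; the URL assumes the server's default local address, and everything else is illustrative):

    package main

    import (
    	"encoding/json"
    	"errors"
    	"io"
    	"log"
    	"net/http"
    	"strings"
    )

    func main() {
    	resp, err := http.Post("http://localhost:11434/api/pull", "application/json",
    		strings.NewReader(`{"model": "library/smol"}`))
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer resp.Body.Close()

    	// The handler writes newline-delimited JSON; decode until EOF.
    	dec := json.NewDecoder(resp.Body)
    	for {
    		var u struct {
    			Status    string `json:"status"`
    			Digest    string `json:"digest"`
    			Total     int64  `json:"total"`
    			Completed int64  `json:"completed"`
    		}
    		if err := dec.Decode(&u); errors.Is(err, io.EOF) {
    			break
    		} else if err != nil {
    			log.Fatal(err)
    		}
    		log.Printf("%s %s %d/%d", u.Status, u.Digest, u.Completed, u.Total)
    	}
    }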

View File

@@ -1,27 +1,17 @@
 package registry
 
 import (
-	"bytes"
-	"context"
 	"encoding/json"
-	"fmt"
-	"io"
-	"io/fs"
-	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"regexp"
 	"strings"
-	"sync"
 	"testing"
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/testutil"
-
-	"golang.org/x/tools/txtar"
-
-	_ "embed"
 )
 
 type panicTransport struct{}
@@ -40,7 +30,7 @@ type bytesResetter interface {
 	Reset()
 }
 
-func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
+func newTestServer(t *testing.T) *Local {
 	t.Helper()
 	dir := t.TempDir()
 	err := os.CopyFS(dir, os.DirFS("testdata/models"))
@@ -51,26 +41,11 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
 	if err != nil {
 		t.Fatal(err)
 	}
-	client := panicOnRoundTrip
-	if upstreamRegistry != nil {
-		s := httptest.NewTLSServer(upstreamRegistry)
-		t.Cleanup(s.Close)
-		tr := s.Client().Transport.(*http.Transport).Clone()
-		tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
-			var d net.Dialer
-			return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
-		}
-		client = &http.Client{Transport: tr}
-	}
 	rc := &ollama.Registry{
-		Cache:      c,
-		HTTPClient: client,
-		Mask:       "example.com/library/_:latest",
+		HTTPClient: panicOnRoundTrip,
 	}
 	l := &Local{
+		Cache:  c,
 		Client: rc,
 		Logger: testutil.Slogger(t),
 	}
@@ -110,9 +85,9 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
 func TestServerDelete(t *testing.T) {
 	check := testutil.Checker(t)
 
-	s := newTestServer(t, nil)
-	_, err := s.Client.ResolveLocal("smol")
+	s := newTestServer(t)
+	_, err := s.Client.ResolveLocal(s.Cache, "smol")
 	check(err)
 
 	got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
@@ -120,7 +95,7 @@ func TestServerDelete(t *testing.T) {
 		t.Fatalf("Code = %d; want 200", got.Code)
 	}
 
-	_, err = s.Client.ResolveLocal("smol")
+	_, err = s.Client.ResolveLocal(s.Cache, "smol")
 	if err == nil {
 		t.Fatal("expected smol to have been deleted")
 	}
@@ -152,105 +127,8 @@ func TestServerDelete(t *testing.T) {
 	}
 }
 
-//go:embed testdata/registry.txt
-var registryTXT []byte
-
-var registryFS = sync.OnceValue(func() fs.FS {
-	// Txtar gets hung up on \r\n line endings, so we need to convert them
-	// to \n when parsing the txtar on Windows.
-	data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
-	a := txtar.Parse(data)
-	fmt.Printf("%q\n", a.Comment)
-	fsys, err := txtar.FS(a)
-	if err != nil {
-		panic(err)
-	}
-	return fsys
-})
-
-func TestServerPull(t *testing.T) {
-	modelsHandler := http.FileServerFS(registryFS())
-	s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
-		switch r.URL.Path {
-		case "/v2/library/BOOM/manifests/latest":
-			w.WriteHeader(999)
-			io.WriteString(w, `{"error": "boom"}`)
-		case "/v2/library/unknown/manifests/latest":
-			w.WriteHeader(404)
-			io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
-		default:
-			t.Logf("serving file: %s", r.URL.Path)
-			modelsHandler.ServeHTTP(w, r)
-		}
-	})
-
-	checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
-		t.Helper()
-
-		if got.Code != 200 {
-			t.Fatalf("Code = %d; want 200", got.Code)
-		}
-		gotlines := got.Body.String()
-		t.Logf("got:\n%s", gotlines)
-		for want := range strings.Lines(wantlines) {
-			want = strings.TrimSpace(want)
-			want, unwanted := strings.CutPrefix(want, "!")
-			want = strings.TrimSpace(want)
-			if !unwanted && !strings.Contains(gotlines, want) {
-				t.Fatalf("! missing %q in body", want)
-			}
-			if unwanted && strings.Contains(gotlines, want) {
-				t.Fatalf("! unexpected %q in body", want)
-			}
-		}
-	}
-
-	got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
-	checkResponse(got, `
-		{"status":"pulling manifest"}
-		{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
-	`)
-
-	got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
-	checkResponse(got, `
-		{"status":"pulling manifest"}
-		{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
-		{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
-		{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
-		{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
-		{"status":"verifying layers"}
-		{"status":"writing manifest"}
-		{"status":"success"}
-	`)
-
-	got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
-	checkResponse(got, `
-		{"status":"pulling manifest"}
-		{"status":"error: model \"unknown\" not found"}
-	`)
-
-	got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
-	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
-
-	got = s.send(t, "POST", "/api/pull", `!`)
-	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
-
-	got = s.send(t, "POST", "/api/pull", ``)
-	checkErrorResponse(t, got, 400, "bad_request", "empty request body")
-
-	got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
-	checkResponse(got, `
-		{"status":"pulling manifest"}
-		{"status":"error: invalid or missing name: \"\""}
-		!verifying
-		!writing
-		!success
-	`)
-}
-
 func TestServerUnknownPath(t *testing.T) {
-	s := newTestServer(t, nil)
+	s := newTestServer(t)
 	got := s.send(t, "DELETE", "/api/unknown", `{}`)
 	checkErrorResponse(t, got, 404, "not_found", "not found")
 }
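The deleted registryFS helper builds an fs.FS from a txtar archive so the fake upstream registry can be served with http.FileServerFS. The round trip in isolation (txtar.Parse and txtar.FS are the golang.org/x/tools/txtar APIs used above; the fixture content here is illustrative):

    // Assumes imports: fmt, io/fs, golang.org/x/tools/txtar.
    a := txtar.Parse([]byte("fixture comment\n-- hello.txt --\nhi\n"))
    fsys, err := txtar.FS(a)
    if err != nil {
    	panic(err)
    }
    b, _ := fs.ReadFile(fsys, "hello.txt")
    fmt.Printf("%s", b) // prints "hi\n"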

View File

@@ -1,22 +0,0 @@
--- v2/library/smol/manifests/latest --
-{
-  "schemaVersion": 2,
-  "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
-  "config": {
-    "mediaType": "application/vnd.docker.container.image.v1+json",
-    "digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
-    "size": 3
-  },
-  "layers": [
-    {
-      "mediaType": "application/vnd.ollama.image.model",
-      "digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
-      "size": 5
-    }
-  ]
-}
--- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
-GGUF
--- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
-{}
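The deleted fixture doubles as a compact reference for the manifest shape the client parses: a Docker v2 manifest whose config and layers reference blobs by digest and size. A hypothetical Go mirror of that JSON (field names inferred from the keys above; the real types in the client may differ):

    // Hypothetical; mirrors the JSON keys above, not the repo's actual types.
    type manifest struct {
    	SchemaVersion int     `json:"schemaVersion"`
    	MediaType     string  `json:"mediaType"`
    	Config        layer   `json:"config"`
    	Layers        []layer `json:"layers"`
    }

    type layer struct {
    	MediaType string `json:"mediaType"`
    	Digest    string `json:"digest"`
    	Size      int64  `json:"size"`
    }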

View File

@@ -10,6 +10,7 @@ import (
 	"strings"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/template"
@@ -26,7 +27,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	var system []api.Message
 
 	isMllama := checkMllamaModelFamily(m)
-	isGemma3 := checkGemma3ModelFamily(m)
 
 	var imageNumTokens int
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,7 +41,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n; i >= 0; i-- {
-		if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
+		if isMllama && len(msgs[i].Images) > 1 {
 			return "", nil, errTooManyImages
 		}
@@ -93,7 +93,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			var imgData llm.ImageData
 
 			if isMllama {
-				if len(m.ProjectorPaths) == 0 {
+				if envconfig.NewEngine() {
 					imgData = llm.ImageData{
 						ID:   len(images),
 						Data: i,
@@ -158,12 +158,3 @@ func checkMllamaModelFamily(m *Model) bool {
 	}
 	return false
 }
-
-func checkGemma3ModelFamily(m *Model) bool {
-	for _, arch := range m.Config.ModelFamilies {
-		if arch == "gemma3" {
-			return true
-		}
-	}
-	return false
-}
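As a side note, a family check of this loop shape collapses to a one-liner with the standard library (a sketch, assuming Go 1.21+ and the same ModelFamilies string slice):

    import "slices"

    func checkGemma3ModelFamily(m *Model) bool {
    	return slices.Contains(m.Config.ModelFamilies, "gemma3")
    }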

View File

@@ -34,6 +34,7 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/openai"
+	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"
@@ -42,12 +43,6 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-func experimentEnabled(name string) bool {
-	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
-}
-
-var useClient2 = experimentEnabled("client2")
-
 var mode string = gin.DebugMode
 
 type Server struct {
@@ -211,7 +206,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
-		if isMllama && len(model.ProjectorPaths) > 0 {
+		if isMllama && !envconfig.NewEngine() {
 			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
 			if err != nil {
 				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@@ -435,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return
 	}
 
-	kvData, _, err := getModelData(m.ModelPath, false)
+	kvData, err := getKVData(m.ModelPath, false)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -483,7 +478,8 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	}
 
 	if err := g.Wait(); err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+		slog.Error("embedding generation failed", "error", err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
 		return
 	}
@@ -544,7 +540,8 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 	embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
+		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embedding: %v", err)})
 		return
 	}
@@ -848,23 +845,16 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	fmt.Fprint(&sb, m.String())
 	resp.Modelfile = sb.String()
 
-	kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
+	kvData, err := getKVData(m.ModelPath, req.Verbose)
 	if err != nil {
 		return nil, err
 	}
 	delete(kvData, "general.name")
 	delete(kvData, "tokenizer.chat_template")
 	resp.ModelInfo = kvData
 
-	tensorData := make([]api.Tensor, len(tensors.Items()))
-	for cnt, t := range tensors.Items() {
-		tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
-	}
-	resp.Tensors = tensorData
-
 	if len(m.ProjectorPaths) > 0 {
-		projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
+		projectorData, err := getKVData(m.ProjectorPaths[0], req.Verbose)
 		if err != nil {
 			return nil, err
 		}
@@ -874,17 +864,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 	return resp, nil
 }
 
-func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
+func getKVData(digest string, verbose bool) (ggml.KV, error) {
 	maxArraySize := 0
 	if verbose {
 		maxArraySize = -1
 	}
-	data, err := llm.LoadModel(digest, maxArraySize)
+	kvData, err := llm.LoadModel(digest, maxArraySize)
 	if err != nil {
-		return nil, ggml.Tensors{}, err
+		return nil, err
 	}
 
-	kv := data.KV()
+	kv := kvData.KV()
 
 	if !verbose {
 		for k := range kv {
@@ -894,7 +884,7 @@ func getKVData(digest string, verbose bool) (ggml.KV, error) {
 		}
 	}
 
-	return kv, data.Tensors(), nil
+	return kv, nil
 }
 
 func (s *Server) ListHandler(c *gin.Context) {
@@ -1139,7 +1129,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 	}
 }
 
-func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
+func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
 	corsConfig := cors.DefaultConfig()
 	corsConfig.AllowWildcard = true
 	corsConfig.AllowBrowserExtensions = true
@@ -1184,7 +1174,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
 	r.HEAD("/api/tags", s.ListHandler)
 	r.GET("/api/tags", s.ListHandler)
 	r.POST("/api/show", s.ShowHandler)
-	r.DELETE("/api/delete", s.DeleteHandler)
 
 	// Create
 	r.POST("/api/create", s.CreateHandler)
@@ -1206,19 +1195,15 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
 
-	if rc != nil {
-		// wrap old with new
-		rs := &registry.Local{
-			Client:   rc,
-			Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
-			Fallback: r,
-
-			Prune: PruneLayers,
-		}
-		return rs, nil
+	// wrap old with new
+	rs := &registry.Local{
+		Cache:    c,
+		Client:   rc,
+		Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
+		Fallback: r,
 	}
 
-	return r, nil
+	return rs, nil
 }
 
 func Serve(ln net.Listener) error {
@@ -1273,20 +1258,19 @@ func Serve(ln net.Listener) error {
 	s := &Server{addr: ln.Addr()}
 
-	var rc *ollama.Registry
-	if useClient2 {
-		var err error
-		rc, err = ollama.DefaultRegistry()
-		if err != nil {
-			return err
-		}
+	c, err := ollama.DefaultCache()
+	if err != nil {
+		return err
 	}
 
-	h, err := s.GenerateRoutes(rc)
+	rc, err := ollama.DefaultRegistry()
 	if err != nil {
 		return err
 	}
+
+	h, err := s.GenerateRoutes(c, rc)
+	if err != nil {
+		return err
+	}
 
 	http.Handle("/", h)
 
 	ctx, done := context.WithCancel(context.Background())
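The registry.Local value returned by GenerateRoutes acts as a thin front: it serves the routes it knows and hands everything else to the gin router via Fallback. The pattern in miniature (a sketch of the delegation idea, not the actual registry.Local implementation):

    type front struct {
    	fallback http.Handler // the old gin router
    }

    func (f *front) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    	switch r.URL.Path {
    	case "/api/delete":
    		// new-path handling goes here
    	default:
    		if f.fallback != nil {
    			f.fallback.ServeHTTP(w, r)
    			return
    		}
    		http.NotFound(w, r)
    	}
    }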

View File

@@ -23,6 +23,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/openai"
+	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
@@ -489,6 +490,11 @@ func TestRoutes(t *testing.T) {
 	modelsDir := t.TempDir()
 	t.Setenv("OLLAMA_MODELS", modelsDir)
 
+	c, err := blob.Open(modelsDir)
+	if err != nil {
+		t.Fatalf("failed to open models dir: %v", err)
+	}
+
 	rc := &ollama.Registry{
 		// This is a temporary measure to allow us to move forward,
 		// surfacing any code contacting ollama.com we do not intended
@@ -505,7 +511,7 @@ func TestRoutes(t *testing.T) {
 	}
 
 	s := &Server{}
-	router, err := s.GenerateRoutes(rc)
+	router, err := s.GenerateRoutes(c, rc)
 	if err != nil {
 		t.Fatalf("failed to generate routes: %v", err)
 	}
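The blob.Open call above is the test-side counterpart of DefaultCache: it opens a disk cache over an explicit directory instead of the default models location. A hypothetical helper following that pattern (blob.Open and *blob.DiskCache are as used in this compare):

    // Hypothetical test helper, not part of the diff above.
    func openTestCache(t *testing.T) *blob.DiskCache {
    	t.Helper()
    	c, err := blob.Open(t.TempDir())
    	if err != nil {
    		t.Fatal(err)
    	}
    	return c
    }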