Compare commits

..

1 Commits

Author SHA1 Message Date
Patrick Devine
c10a40db99 parser: tidy up parameter/message parsing
This change addresses how parameters and messages are handled while parsing a Modelfile.
Currently a `MESSAGE` command creates a string that looks like "<role>: <message>" which
then has to be re-parsed, and `PARAMETER` ends up being the default data type which makes
it difficult to add other multi-part commands.

This change introduces a Message and a Parameter type which properly handle properties
such as the role and the name of the parameter.
2025-09-15 18:09:05 -07:00
81 changed files with 855 additions and 4348 deletions

1
.gitignore vendored
View File

@@ -6,7 +6,6 @@
dist dist
build build
.cache .cache
.gocache
*.exe *.exe
.idea .idea
test_data test_data

View File

@@ -1,7 +1,6 @@
# vim: filetype=dockerfile # vim: filetype=dockerfile
ARG FLAVOR=${TARGETARCH} ARG FLAVOR=${TARGETARCH}
ARG PARALLEL=8
ARG ROCMVERSION=6.3.3 ARG ROCMVERSION=6.3.3
ARG JETPACK5VERSION=r35.4.1 ARG JETPACK5VERSION=r35.4.1
@@ -35,51 +34,46 @@ ENV LDFLAGS=-s
FROM base AS cpu FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++ RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \ cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \ && cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel ${PARALLEL} && cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11 FROM base AS cuda-11
ARG CUDA11VERSION=11.8 ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \ cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \ && cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\ cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \ && cmake --build --parallel --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-13 FROM base AS cuda-13
ARG CUDA13VERSION=13.0 ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \ cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \ && cmake --build --parallel --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip --parallel 8
FROM base AS rocm-6 FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \ cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \ && cmake --build --parallel --preset 'ROCm 6' \
&& cmake --install build --component HIP --strip --parallel ${PARALLEL} && cmake --install build --component HIP --strip --parallel 8
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
ARG CMAKEVERSION ARG CMAKEVERSION
@@ -87,11 +81,10 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 5' \ cmake --preset 'JetPack 5' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \ && cmake --build --parallel --preset 'JetPack 5' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip --parallel 8
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6 FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
ARG CMAKEVERSION ARG CMAKEVERSION
@@ -99,11 +92,10 @@ RUN apt-get update && apt-get install -y curl ccache \
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1 && curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json . COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ARG PARALLEL
RUN --mount=type=cache,target=/root/.ccache \ RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'JetPack 6' \ cmake --preset 'JetPack 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ && cmake --build --parallel --preset 'JetPack 6' \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL} && cmake --install build --component CUDA --strip --parallel 8
FROM base AS build FROM base AS build
WORKDIR /go/src/github.com/ollama/ollama WORKDIR /go/src/github.com/ollama/ollama

View File

@@ -45,12 +45,6 @@ func checkError(resp *http.Response, body []byte) error {
return nil return nil
} }
if resp.StatusCode == http.StatusUnauthorized {
authError := AuthorizationError{StatusCode: resp.StatusCode}
json.Unmarshal(body, &authError)
return authError
}
apiError := StatusError{StatusCode: resp.StatusCode} apiError := StatusError{StatusCode: resp.StatusCode}
err := json.Unmarshal(body, &apiError) err := json.Unmarshal(body, &apiError)
@@ -220,8 +214,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanner.Buffer(scanBuf, maxBufferSize) scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() { for scanner.Scan() {
var errorResponse struct { var errorResponse struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
SigninURL string `json:"signin_url,omitempty"`
} }
bts := scanner.Bytes() bts := scanner.Bytes()
@@ -229,13 +222,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf("unmarshal: %w", err) return fmt.Errorf("unmarshal: %w", err)
} }
if response.StatusCode == http.StatusUnauthorized { if response.StatusCode >= http.StatusBadRequest {
return AuthorizationError{
StatusCode: response.StatusCode,
Status: response.Status,
SigninURL: errorResponse.SigninURL,
}
} else if response.StatusCode >= http.StatusBadRequest {
return StatusError{ return StatusError{
StatusCode: response.StatusCode, StatusCode: response.StatusCode,
Status: response.Status, Status: response.Status,
@@ -441,21 +428,3 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil return version.Version, nil
} }
// Signout will signout a client for a local ollama server.
func (c *Client) Signout(ctx context.Context) error {
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
}
// Disconnect will disconnect an ollama instance from ollama.com.
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
}
func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
var resp UserResponse
if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -11,8 +11,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@@ -38,19 +36,6 @@ func (e StatusError) Error() string {
} }
} }
type AuthorizationError struct {
StatusCode int
Status string
SigninURL string `json:"signin_url"`
}
func (e AuthorizationError) Error() string {
if e.Status != "" {
return e.Status
}
return "something went wrong, please see the ollama server logs for details"
}
// ImageData represents the raw binary data of an image file. // ImageData represents the raw binary data of an image file.
type ImageData []byte type ImageData []byte
@@ -328,29 +313,13 @@ func (t *ToolFunction) String() string {
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse]. // similar to [GenerateResponse].
type ChatResponse struct { type ChatResponse struct {
// Model is the model name that generated the response. Model string `json:"model"`
Model string `json:"model"` CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
// RemoteModel is the name of the upstream model that generated the response.
RemoteModel string `json:"remote_model,omitempty"`
// RemoteHost is the URL of the upstream Ollama host that generated the response.
RemoteHost string `json:"remote_host,omitempty"`
// CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"`
// Message contains the message or part of a message from the model.
Message Message `json:"message"`
// Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`
// DoneReason is the reason the model stopped generating text.
DoneReason string `json:"done_reason,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
Metrics Metrics
} }
@@ -360,6 +329,13 @@ type DebugInfo struct {
ImageCount int `json:"image_count,omitempty"` ImageCount int `json:"image_count,omitempty"`
} }
// DebugTemplateResponse is returned when _debug_render_only is set to true
type DebugTemplateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
DebugInfo DebugInfo `json:"_debug_info"`
}
type Metrics struct { type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"` TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"`
@@ -455,47 +431,18 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create]. // CreateRequest is the request passed to [Client.Create].
type CreateRequest struct { type CreateRequest struct {
// Model is the model name to create. Model string `json:"model"`
Model string `json:"model"` Stream *bool `json:"stream,omitempty"`
// Stream specifies whether the response is streaming; it is true by default.
Stream *bool `json:"stream,omitempty"`
// Quantize is the quantization format for the model; leave blank to not change the quantization level.
Quantize string `json:"quantize,omitempty"` Quantize string `json:"quantize,omitempty"`
// From is the name of the model or file to use as the source. From string `json:"from,omitempty"`
From string `json:"from,omitempty"` Files map[string]string `json:"files,omitempty"`
Adapters map[string]string `json:"adapters,omitempty"`
// RemoteHost is the URL of the upstream ollama API for the model (if any). Template string `json:"template,omitempty"`
RemoteHost string `json:"remote_host,omitempty"` License any `json:"license,omitempty"`
System string `json:"system,omitempty"`
// Files is a map of files include when creating the model. Parameters map[string]any `json:"parameters,omitempty"`
Files map[string]string `json:"files,omitempty"` Messages []Message `json:"messages,omitempty"`
// Adapters is a map of LoRA adapters to include when creating the model.
Adapters map[string]string `json:"adapters,omitempty"`
// Template is the template used when constructing a request to the model.
Template string `json:"template,omitempty"`
// License is a string or list of strings for licenses.
License any `json:"license,omitempty"`
// System is the system prompt for the model.
System string `json:"system,omitempty"`
// Parameters is a map of hyper-parameters which are applied to the model.
Parameters map[string]any `json:"parameters,omitempty"`
// Messages is a list of messages added to the model before chat and generation requests.
Messages []Message `json:"messages,omitempty"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
// Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"`
// Deprecated: set the model name with Model instead // Deprecated: set the model name with Model instead
Name string `json:"name"` Name string `json:"name"`
@@ -533,12 +480,8 @@ type ShowResponse struct {
Parameters string `json:"parameters,omitempty"` Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"`
ProjectorInfo map[string]any `json:"projector_info,omitempty"` ProjectorInfo map[string]any `json:"projector_info,omitempty"`
Tensors []Tensor `json:"tensors,omitempty"` Tensors []Tensor `json:"tensors,omitempty"`
@@ -597,14 +540,12 @@ type ProcessResponse struct {
// ListModelResponse is a single model description in [ListResponse]. // ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct { type ListModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
RemoteModel string `json:"remote_model,omitempty"` ModifiedAt time.Time `json:"modified_at"`
RemoteHost string `json:"remote_host,omitempty"` Size int64 `json:"size"`
ModifiedAt time.Time `json:"modified_at"` Digest string `json:"digest"`
Size int64 `json:"size"` Details ModelDetails `json:"details,omitempty"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
} }
// ProcessModelResponse is a single model description in [ProcessResponse]. // ProcessModelResponse is a single model description in [ProcessResponse].
@@ -628,12 +569,6 @@ type GenerateResponse struct {
// Model is the model name that generated the response. // Model is the model name that generated the response.
Model string `json:"model"` Model string `json:"model"`
// RemoteModel is the name of the upstream model that generated the response.
RemoteModel string `json:"remote_model,omitempty"`
// RemoteHost is the URL of the upstream Ollama host that generated the response.
RemoteHost string `json:"remote_host,omitempty"`
// CreatedAt is the timestamp of the response. // CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
@@ -657,8 +592,6 @@ type GenerateResponse struct {
Metrics Metrics
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
} }
// ModelDetails provides details about a model. // ModelDetails provides details about a model.
@@ -671,18 +604,6 @@ type ModelDetails struct {
QuantizationLevel string `json:"quantization_level"` QuantizationLevel string `json:"quantization_level"`
} }
// UserResponse provides information about a user.
type UserResponse struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Bio string `json:"bio,omitempty"`
AvatarURL string `json:"avatarurl,omitempty"`
FirstName string `json:"firstname,omitempty"`
LastName string `json:"lastname,omitempty"`
Plan string `json:"plan,omitempty"`
}
// Tensor describes the metadata for a given tensor. // Tensor describes the metadata for a given tensor.
type Tensor struct { type Tensor struct {
Name string `json:"name"` Name string `json:"name"`

View File

@@ -18,13 +18,21 @@ import (
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func GetPublicKey() (string, error) { func keyPath() (string, error) {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return "", err
} }
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
@@ -51,12 +59,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
} }
func Sign(ctx context.Context, bts []byte) (string, error) { func Sign(ctx context.Context, bts []byte) (string, error) {
home, err := os.UserHomeDir() keyPath, err := keyPath()
if err != nil { if err != nil {
return "", err return "", err
} }
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View File

@@ -47,8 +47,6 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support // ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" { if name == "" {
@@ -288,17 +286,7 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
Think: opts.Think, Think: opts.Think,
} }
return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error { return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
if r.RemoteModel != "" && opts.ShowConnect {
p.StopAndClear()
if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
}
}
return nil
})
} }
func StopHandler(cmd *cobra.Command, args []string) error { func StopHandler(cmd *cobra.Command, args []string) error {
@@ -319,10 +307,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
opts := runOptions{ opts := runOptions{
Model: args[0], Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color", WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{}, Options: map[string]any{},
ShowConnect: true,
} }
format, err := cmd.Flags().GetString("format") format, err := cmd.Flags().GetString("format")
@@ -380,7 +367,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
prompts = append([]string{string(in)}, prompts...) prompts = append([]string{string(in)}, prompts...)
opts.ShowConnect = false
opts.WordWrap = false opts.WordWrap = false
interactive = false interactive = false
} }
@@ -447,15 +433,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if interactive { if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if sErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, sErr.SigninURL)
}
return nil
}
return err return err
} }
@@ -476,59 +453,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts) return generate(cmd, opts)
} }
func SigninHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
user, err := client.Whoami(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
fmt.Println()
if aErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
}
return err
}
if user != nil && user.Name != "" {
fmt.Printf("You are already signed in as user '%s'\n", user.Name)
fmt.Println()
return nil
}
return nil
}
func SignoutHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
err = client.Signout(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You are not signed in to ollama.com")
fmt.Println()
return nil
} else {
return err
}
}
fmt.Println("You have signed out of ollama.com")
fmt.Println()
return nil
}
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@@ -581,8 +505,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
} }
errStr := strings.ToLower(err.Error()) if strings.Contains(err.Error(), "access denied") {
if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
} }
return err return err
@@ -616,14 +539,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
for _, m := range models.Models { for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) { if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
var size string data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
if m.RemoteModel != "" {
size = "-"
} else {
size = format.HumanBytes(m.Size)
}
data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
} }
} }
@@ -708,8 +624,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
KeepAlive: &api.Duration{Duration: 0}, KeepAlive: &api.Duration{Duration: 0},
} }
if err := loadOrUnloadModel(cmd, opts); err != nil { if err := loadOrUnloadModel(cmd, opts); err != nil {
if !strings.Contains(strings.ToLower(err.Error()), "not found") { if !strings.Contains(err.Error(), "not found") {
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0]) return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
} }
} }
@@ -820,36 +736,12 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
} }
tableRender("Model", func() (rows [][]string) { tableRender("Model", func() (rows [][]string) {
if resp.RemoteHost != "" {
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
}
if resp.ModelInfo != nil { if resp.ModelInfo != nil {
arch := resp.ModelInfo["general.architecture"].(string) arch := resp.ModelInfo["general.architecture"].(string)
rows = append(rows, []string{"", "architecture", arch}) rows = append(rows, []string{"", "architecture", arch})
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
var paramStr string rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
if resp.Details.ParameterSize != "" { rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
paramStr = resp.Details.ParameterSize
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
if f, ok := v.(float64); ok {
paramStr = format.HumanNumber(uint64(f))
}
}
rows = append(rows, []string{"", "parameters", paramStr})
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
if f, ok := v.(float64); ok {
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
}
}
} else { } else {
rows = append(rows, []string{"", "architecture", resp.Details.Family}) rows = append(rows, []string{"", "architecture", resp.Details.Family})
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
@@ -1097,7 +989,6 @@ type runOptions struct {
KeepAlive *api.Duration KeepAlive *api.Duration
Think *api.ThinkValue Think *api.ThinkValue
HideThinking bool HideThinking bool
ShowConnect bool
} }
type displayResponseState struct { type displayResponseState struct {
@@ -1653,22 +1544,6 @@ func NewCLI() *cobra.Command {
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
signinCmd := &cobra.Command{
Use: "signin",
Short: "Sign in to ollama.com",
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SigninHandler,
}
signoutCmd := &cobra.Command{
Use: "signout",
Short: "Sign out from ollama.com",
Args: cobra.ExactArgs(0),
PreRunE: checkServerHeartbeat,
RunE: SignoutHandler,
}
listCmd := &cobra.Command{ listCmd := &cobra.Command{
Use: "list", Use: "list",
Aliases: []string{"ls"}, Aliases: []string{"ls"},
@@ -1763,8 +1638,6 @@ func NewCLI() *cobra.Command {
stopCmd, stopCmd,
pullCmd, pullCmd,
pushCmd, pushCmd,
signinCmd,
signoutCmd,
listCmd, listCmd,
psCmd, psCmd,
copyCmd, copyCmd,

View File

@@ -3,7 +3,6 @@ package cmd
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -305,8 +304,6 @@ func TestDeleteHandler(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
errPayload := `{"error":"model '%s' not found"}`
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
} }
return return
} }
@@ -349,7 +346,7 @@ func TestDeleteHandler(t *testing.T) {
} }
err := DeleteHandler(cmd, []string{"test-model-not-found"}) err := DeleteHandler(cmd, []string{"test-model-not-found"})
if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") { if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err) t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
} }
} }
@@ -502,7 +499,7 @@ func TestPushHandler(t *testing.T) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{ err := json.NewEncoder(w).Encode(map[string]string{
"error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}", "error": "access denied",
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -525,10 +522,6 @@ func TestPushHandler(t *testing.T) {
defer mockServer.Close() defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)
initializeKeypair()
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "") cmd.Flags().Bool("insecure", false, "")

View File

@@ -96,7 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 { func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind() kind := st.tensorBase.Kind()
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 { if st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16 kind = tensorKindBF16
} }

View File

@@ -230,65 +230,3 @@ func TestSafetensors(t *testing.T) {
}) })
} }
} }
func TestSafetensorKind(t *testing.T) {
tests := []struct {
name string
st safetensor
expected uint32
}{
{
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindBF16,
},
{
name: "BF16 dtype with v. prefix should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "v.weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "BF16",
},
expected: tensorKindFP16,
},
{
name: "BF16 dtype with FP32 base kind should return FP32",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10}, // will default to FP32
},
dtype: "BF16",
},
expected: tensorKindFP32,
},
{
name: "Non-BF16 dtype should return base kind",
st: safetensor{
tensorBase: &tensorBase{
name: "weight.matrix",
shape: []uint64{10, 10}, // will default to FP16
},
dtype: "FP16",
},
expected: tensorKindFP16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.st.Kind()
if result != tt.expected {
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
}
})
}
}

View File

@@ -16,7 +16,7 @@ import (
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices. // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK") var CudaTegra string = os.Getenv("JETSON_JETPACK")
func cudaVariant(gpuInfos []CudaGPUInfo) string { func cudaVariant(gpuInfo CudaGPUInfo) string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" { if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" { if CudaTegra != "" {
ver := strings.Split(CudaTegra, ".") ver := strings.Split(CudaTegra, ".")
@@ -45,19 +45,12 @@ func cudaVariant(gpuInfos []CudaGPUInfo) string {
} }
} }
// Check GPU compute capability FIRST, lowest common denominator if multi-gpu if gpuInfo.DriverMajor < 13 {
for _, gpuInfo := range gpuInfos {
if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
return "v12"
}
}
// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025) // The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor)) if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
}
return "v12" return "v12"
} }
return "v13" return "v13"

View File

@@ -284,8 +284,18 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor gpuInfo.DriverMinor = driverMinor
variant := cudaVariant(gpuInfo)
// Start with our bundled libraries
if variant != "" {
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err == nil {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
}
}
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.Variant = variant
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) { if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
unsupportedGPUs = append(unsupportedGPUs, unsupportedGPUs = append(unsupportedGPUs,
@@ -323,24 +333,6 @@ func GetGPUInfo() GpuInfoList {
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does... // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo) cudaGPUs = append(cudaGPUs, gpuInfo)
} }
// Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
variant := cudaVariant(cudaGPUs)
var variantPath string
// Start with our bundled libraries
if variant != "" {
variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
if _, err := os.Stat(variantPath); err != nil {
variantPath = ""
}
}
for i := range cudaGPUs {
cudaGPUs[i].Variant = variant
if variantPath != "" {
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
}
}
} }
// Intel // Intel

View File

@@ -1,40 +0,0 @@
# Cloud
| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud).
## Cloud Models
[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldnt fit on a personal computer.
Ollama currently supports the following cloud models, with more coming soon:
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `deepseek-v3.1:671b-cloud`
- `qwen3-coder:480b-cloud`
### Get started
To run a cloud model, open the terminal and run:
```
ollama run gpt-oss:120b-cloud
```
To run cloud models with integrations that work with Ollama, first download the cloud model:
```
ollama pull qwen3-coder:480b-cloud
```
Then sign in to Ollama:
```
ollama signin
```
Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling.
## Cloud API access
Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud).

107
docs/turbo.md Normal file
View File

@@ -0,0 +1,107 @@
# Turbo
>  Turbo is preview
Ollamas [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
Currently, the following models are available in Turbo:
- `gpt-oss:20b`
- `gpt-oss:120b`
## Get started
### Ollama for macOS & Windows
Download Ollama
- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
- Click on **Turbo**. Youll be prompted to create an account or sign in
### Ollamas CLI
- [Sign up](https://ollama.com/signup) for an Ollama account
- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
On macOS and Linux:
```shell
cat ~/.ollama/id_ed25519.pub
```
On Windows:
```
type "%USERPROFILE%\.ollama\id_ed25519.pub"
```
- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
```shell
OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
```
### Ollamas Python library
- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
- [Sign up](https://ollama.com/signup) for an Ollama account
- Create an API key by visiting https://ollama.com/settings/keys
```python
from ollama import Client
client = Client(
host="https://ollama.com",
headers={'Authorization': '<api key>'}
)
messages = [
{
'role': 'user',
'content': 'Why is the sky blue?',
},
]
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
print(part['message']['content'], end='', flush=True)
```
### Ollamas JavaScript library
- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
- [Sign up](https://ollama.com/signup) for an Ollama account
- Create an API key by visiting https://ollama.com/settings/keys
```typescript
import { Ollama } from 'ollama';
const ollama = new Ollama({
host: 'https://ollama.com',
headers: {
Authorization: "Bearer <api key>"
}
});
const response = await ollama.chat({
model: 'gpt-oss:120b',
messages: [{ role: 'user', content: 'Explain quantum computing' }],
stream: true
});
for await (const part of response) {
process.stdout.write(part.message.content)
}
```
### Community integrations
Turbo mode is also compatible with several community integrations.
#### Open WebUI
- Go to **settings** → **Admin settings** → **Connections**
- Under **Ollama API,** click **+**
- For the **URL** put `https://ollama.com`
- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
- Click **Save**
Now, if you navigate to the model selector, Turbo models should be available under **External**.

View File

@@ -134,17 +134,6 @@ func LoadTimeout() (loadTimeout time.Duration) {
return loadTimeout return loadTimeout
} }
func Remotes() []string {
var r []string
raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
if raw == "" {
r = []string{"ollama.com"}
} else {
r = strings.Split(raw, ",")
}
return r
}
func Bool(k string) func() bool { func Bool(k string) func() bool {
return func() bool { return func() bool {
if s := Var(k); s != "" { if s := Var(k); s != "" {
@@ -281,7 +270,6 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
// Informational // Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

View File

@@ -243,8 +243,6 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3", "gemma3",
"gemma3n", "gemma3n",
"mistral3", "mistral3",
"qwen3",
"qwen3moe",
"llama4", "llama4",
"mllama", "mllama",
"qwen25vl", "qwen25vl",

View File

@@ -1,7 +1,6 @@
package harmony package harmony
import ( import (
"encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"strings" "strings"
@@ -266,8 +265,6 @@ type HarmonyMessageHandler struct {
state harmonyMessageState state harmonyMessageState
HarmonyParser *HarmonyParser HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap FunctionNameMap *FunctionNameMap
toolAccumulator *HarmonyToolCallAccumulator
convertedTools map[string]struct{}
} }
// NewHarmonyMessageHandler creates a new message handler // NewHarmonyMessageHandler creates a new message handler
@@ -280,7 +277,6 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>", HeaderEndTag: "<|message|>",
}, },
FunctionNameMap: NewFunctionNameMap(), FunctionNameMap: NewFunctionNameMap(),
convertedTools: make(map[string]struct{}),
} }
} }
@@ -388,85 +384,8 @@ func NewFunctionNameMap() *FunctionNameMap {
} }
} }
// Init initializes the handler with tools and optional last message
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{
MessageStartTag: "<|start|>",
MessageEndTag: "<|end|>",
HeaderEndTag: "<|message|>",
}
}
// Handle prefill for chat mode
if lastMessage != nil {
h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
} else {
h.HarmonyParser.AddImplicitStart()
}
// Initialize tool accumulator
h.toolAccumulator = h.CreateToolParser()
// Process tools and return renamed versions
if len(tools) == 0 {
return tools
}
processedTools := make([]api.Tool, len(tools))
copy(processedTools, tools)
for i, tool := range processedTools {
if tool.Function.Name != "" {
processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
h.convertedTools[tool.Function.Name] = struct{}{}
}
}
return processedTools
}
// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls
func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
content, thinking, toolContent := h.AddContent(s, h.toolAccumulator)
if toolContent != "" {
h.toolAccumulator.Add(toolContent)
}
// tool calls always happen one at a time, and always at the end of a message,
// so for simplicity we defer parsing them until we know we're done
if done {
toolName, raw := h.toolAccumulator.Drain()
if toolName != nil {
name := strings.TrimPrefix(*toolName, "functions.")
name = h.FunctionNameMap.OriginalFromConverted(name)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(raw), &args); err != nil {
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
}
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
}
}
return content, thinking, calls, nil
}
// HasToolSupport implements the Parser interface
func (h *HarmonyMessageHandler) HasToolSupport() bool {
return true
}
// HasThinkingSupport implements the Parser interface
func (h *HarmonyMessageHandler) HasThinkingSupport() bool {
return true
}
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
harmonyFunctionName := m.deriveName(userFunctionName) harmonyFunctionName := m.deriveName(userFunctionName)
// built-in functions should not be renamed
if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" {
harmonyFunctionName = userFunctionName
}
m.userToHarmony[userFunctionName] = harmonyFunctionName m.userToHarmony[userFunctionName] = harmonyFunctionName
m.harmonyToUser[harmonyFunctionName] = userFunctionName m.harmonyToUser[harmonyFunctionName] = userFunctionName
return harmonyFunctionName return harmonyFunctionName

View File

@@ -513,7 +513,6 @@ func TestFunctionConvertAndAdd(t *testing.T) {
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}}, {name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}}, {name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}}, {name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
{name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}},
} }
for i, tt := range tests { for i, tt := range tests {

View File

@@ -12,6 +12,3 @@ The integration tests have 2 modes of operating.
> [!IMPORTANT] > [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree. > Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`

View File

@@ -22,12 +22,13 @@ func TestAPIGenerate(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: blueSkyPrompt, Prompt: "why is the sky blue? be brief",
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
@@ -119,14 +120,14 @@ func TestAPIGenerate(t *testing.T) {
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()
atLeastOne := false atLeastOne := false
for _, resp := range blueSkyExpected { for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) { if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true atLeastOne = true
break break
} }
} }
if !atLeastOne { if !atLeastOne {
t.Errorf("none of %v found in %s", blueSkyExpected, response) t.Errorf("none of %v found in %s", anyResp, response)
} }
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for generate") t.Error("outer test context done while waiting for generate")
@@ -180,7 +181,7 @@ func TestAPIChat(t *testing.T) {
Messages: []api.Message{ Messages: []api.Message{
{ {
Role: "user", Role: "user",
Content: blueSkyPrompt, Content: "why is the sky blue? be brief",
}, },
}, },
Options: map[string]interface{}{ Options: map[string]interface{}{
@@ -188,6 +189,7 @@ func TestAPIChat(t *testing.T) {
"seed": 123, "seed": 123,
}, },
} }
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
@@ -277,14 +279,14 @@ func TestAPIChat(t *testing.T) {
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()
atLeastOne := false atLeastOne := false
for _, resp := range blueSkyExpected { for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) { if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true atLeastOne = true
break break
} }
} }
if !atLeastOne { if !atLeastOne {
t.Errorf("none of %v found in %s", blueSkyExpected, response) t.Errorf("none of %v found in %s", anyResp, response)
} }
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for chat") t.Error("outer test context done while waiting for chat")

View File

@@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) {
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: blueSkyPrompt, Prompt: "why is the sky blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
GenerateTestHelper(ctx, t, req, blueSkyExpected) GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
} }
func TestUnicode(t *testing.T) { func TestUnicode(t *testing.T) {
@@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) {
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: smol,
Prompt: blueSkyPrompt, Prompt: "why is the sky blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
GenerateTestHelper(ctx, t, req, blueSkyExpected) GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
} }

View File

@@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err) t.Fatalf("PullIfMissing failed: %v", err)
} }
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second) DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
} }
// Send multiple generate requests with prior context and ensure the response is coherant and expected // Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestParallelGenerateWithHistory(t *testing.T) { func TestGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests() req, resp := GenerateRequests()
numParallel := 2 numParallel := 2
@@ -113,48 +113,8 @@ func TestParallelGenerateWithHistory(t *testing.T) {
wg.Wait() wg.Wait()
} }
// Send generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: rainbowPrompt,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"num_ctx": 16384,
},
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
for i := 0; i < len(rainbowFollowups); i++ {
req.Prompt = rainbowFollowups[i]
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
}
}
// Send multiple chat requests with prior context and ensure the response is coherant and expected // Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestParallelChatWithHistory(t *testing.T) { func TestChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests() req, resp := ChatRequests()
numParallel := 2 numParallel := 2
@@ -204,55 +164,3 @@ func TestParallelChatWithHistory(t *testing.T) {
} }
wg.Wait() wg.Wait()
} }
// Send generate requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
req := api.ChatRequest{
Model: smol,
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]any{
"num_ctx": 16384,
},
Messages: []api.Message{
{
Role: "user",
Content: rainbowPrompt,
},
},
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial request
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
for i := 0; i < len(rainbowFollowups); i++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
req.Messages = append(req.Messages,
*assistant,
api.Message{Role: "user", Content: rainbowFollowups[i]},
)
assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}
}
}

View File

@@ -8,7 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -45,8 +44,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
} }
res, err := embeddingTestHelper(ctx, client, t, req) res, err := embeddingTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("error: %v", err)
} }
if len(res.Embedding) != 384 { if len(res.Embedding) != 384 {
@@ -74,8 +74,9 @@ func TestAllMiniLMEmbed(t *testing.T) {
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("error: %v", err)
} }
if len(res.Embeddings) != 1 { if len(res.Embeddings) != 1 {
@@ -111,8 +112,9 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, client, t, req)
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("error: %v", err)
} }
if len(res.Embeddings) != 2 { if len(res.Embeddings) != 2 {
@@ -154,135 +156,93 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
truncTrue, truncFalse := true, false truncTrue, truncFalse := true, false
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ type testReq struct {
Model: "all-minilm", Name string
Input: "why", Request api.EmbedRequest
})
if err != nil {
t.Fatal(err)
} }
cases := []struct { reqs := []testReq{
name string
request api.EmbedRequest
check func(*api.EmbedResponse, error)
}{
{ {
name: "target truncation", Name: "Target Truncation",
request: api.EmbedRequest{ Request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why", Input: "why",
}, },
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
}, },
{ {
name: "default truncate", Name: "Default Truncate",
request: api.EmbedRequest{ Request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3}, Options: map[string]any{"num_ctx": 1},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
}, },
}, },
{ {
name: "explicit truncate", Name: "Explicit Truncate",
request: api.EmbedRequest{ Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
},
},
{
name: "truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Truncate: &truncTrue, Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1}, Options: map[string]any{"num_ctx": 1},
}, },
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
}, },
} }
for _, req := range cases { res := make(map[string]*api.EmbedResponse)
t.Run(req.name, func(t *testing.T) {
req.check(embedTestHelper(ctx, client, t, req.request)) for _, req := range reqs {
}) response, err := embedTestHelper(ctx, client, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
} }
} }
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err) t.Fatalf("failed to pull model %s: %v", req.Model, err)
} }
return client.Embeddings(ctx, &req) response, err := client.Embeddings(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
} }
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
t.Helper()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err) t.Fatalf("failed to pull model %s: %v", req.Model, err)
} }
return client.Embed(ctx, &req) response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
} }

View File

@@ -4,9 +4,7 @@ package integration
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"os"
"testing" "testing"
"time" "time"
@@ -22,7 +20,6 @@ func TestLibraryModelsGenerate(t *testing.T) {
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t) client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
chatModels := libraryChatModels chatModels := libraryChatModels
for _, model := range chatModels { for _, model := range chatModels {
@@ -33,26 +30,16 @@ func TestLibraryModelsGenerate(t *testing.T) {
if err := PullIfMissing(ctx, client, model); err != nil { if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err) t.Fatalf("pull failed %s", err)
} }
if targetArch != "" {
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
if arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
}
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: model, Model: model,
Prompt: blueSkyPrompt, Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0.1, "temperature": 0.1,
"seed": 123, "seed": 123,
}, },
} }
anyResp := blueSkyExpected anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
// Special cases // Special cases
if model == "duckdb-nsql" { if model == "duckdb-nsql" {
anyResp = []string{"select", "from"} anyResp = []string{"select", "from"}

View File

@@ -68,13 +68,14 @@ func TestModelsGenerate(t *testing.T) {
// TODO - fiddle with context size // TODO - fiddle with context size
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: model, Model: model,
Prompt: blueSkyPrompt, Prompt: "why is the sky blue?",
Options: map[string]interface{}{ Options: map[string]interface{}{
"temperature": 0, "temperature": 0,
"seed": 123, "seed": 123,
}, },
} }
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second) anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
}) })
} }
} }

View File

@@ -40,18 +40,6 @@ var (
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv // cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv // cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) { func TestModelsPerf(t *testing.T) {
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
doModelPerfTest(t, ollamaEngineChatModels)
} else {
doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
}
}
func TestLibraryModelsPerf(t *testing.T) {
doModelPerfTest(t, libraryChatModels)
}
func doModelPerfTest(t *testing.T, chatModels []string) {
softTimeout, hardTimeout := getTimeouts(t) softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -77,12 +65,14 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
} }
longPrompt := "summarize the following: " + string(data) longPrompt := "summarize the following: " + string(data)
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels { for _, model := range chatModels {
if !strings.Contains(model, ":") {
model = model + ":latest"
}
t.Run(model, func(t *testing.T) { t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout { if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime") t.Skip("skipping remaining tests to avoid excessive runtime")
@@ -98,9 +88,6 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
} }
arch := resp.ModelInfo["general.architecture"].(string) arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64)) maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if targetArch != "" && arch != targetArch {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
if maxVram > 0 { if maxVram > 0 {
resp, err := client.List(ctx) resp, err := client.List(ctx)
@@ -164,8 +151,8 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
prompt string prompt string
anyResp []string anyResp []string
}{ }{
{blueSkyPrompt, blueSkyExpected}, {"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}}, {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
} }
var gpuPercent int var gpuPercent int
for _, tc := range testCases { for _, tc := range testCases {
@@ -254,12 +241,11 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
} }
} }
} }
// Round the logged prompt count for comparisons across versions/configurations which can vary slightly
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n", fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL", "MODEL",
"CONTEXT", "CONTEXT",
"GPU PERCENT", "GPU PERCENT",
"APPROX PROMPT COUNT", "PROMPT COUNT",
"LOAD TIME", "LOAD TIME",
"PROMPT EVAL TPS", "PROMPT EVAL TPS",
"EVAL TPS", "EVAL TPS",
@@ -268,7 +254,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
model, model,
numCtx, numCtx,
gpuPercent, gpuPercent,
(resp.PromptEvalCount/10)*10, resp.PromptEvalCount,
float64(resp.LoadDuration)/1000000000.0, float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0), float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0), float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),

View File

@@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) {
stream := true stream := true
genReq := api.GenerateRequest{ genReq := api.GenerateRequest{
Model: newName, Model: newName,
Prompt: blueSkyPrompt, Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 3 * time.Second}, KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{ Options: map[string]any{
"seed": 42, "seed": 42,
@@ -88,13 +88,14 @@ func TestQuantization(t *testing.T) {
// Some smaller quantizations can cause models to have poor quality // Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches // or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx) reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false atLeastOne := false
var buf bytes.Buffer var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error { genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response)) buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String()) fullResp := strings.ToLower(buf.String())
for _, resp := range blueSkyExpected { for _, resp := range anyResp {
if strings.Contains(fullResp, resp) { if strings.Contains(fullResp, resp) {
atLeastOne = true atLeastOne = true
t.Log(fullResp) t.Log(fullResp)

View File

@@ -256,29 +256,13 @@ var (
"snowflake-arctic-embed", "snowflake-arctic-embed",
"snowflake-arctic-embed2", "snowflake-arctic-embed2",
} }
blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}
rainbowPrompt = "how do rainbows form? Be brief but factual in your reply"
rainbowFollowups = []string{
"Explain the physics involved in them. Be breif in your reply",
"Explain the chemistry involved in them. Be breif in your reply",
"Explain the quantum mechanics involved in them. Be breif in your reply",
"What are common myths related to them? Be brief in your reply",
"What are common fairytales related to them? Be brief in your reply",
"Can they form if there is no rain? Be breif in your reply",
"Can they form if there are no clouds? Be breif in your reply",
"Do they happen on other planets? Be brief in your reply",
}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"}
) )
func init() { func init() {
lifecycle.InitLogging() lifecycle.InitLogging()
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL") custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL")
if custom != "" { if custom != "" {
slog.Info("setting default test model to " + custom) slog.Info("setting smol test model to " + custom)
smol = custom smol = custom
} }
} }
@@ -593,11 +577,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
}, },
}, },
[][]string{ [][]string{
{"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"}, {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"}, {"water", "droplet", "refracted", "reflect", "color", "spectrum"},
{"fourth", "july", "declaration", "independence"}, {"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"}, {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
} }
} }

View File

@@ -5,8 +5,6 @@ import (
"io" "io"
"log/slog" "log/slog"
"path/filepath" "path/filepath"
"runtime"
"time"
) )
const LevelTrace slog.Level = -8 const LevelTrace slog.Level = -8
@@ -31,18 +29,10 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
})) }))
} }
type key string
func Trace(msg string, args ...any) { func Trace(msg string, args ...any) {
TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...) slog.Log(context.TODO(), LevelTrace, msg, args...)
} }
func TraceContext(ctx context.Context, msg string, args ...any) { func TraceContext(ctx context.Context, msg string, args ...any) {
if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) { slog.Log(ctx, LevelTrace, msg, args...)
skip, _ := ctx.Value(key("skip")).(int)
pc, _, _, _ := runtime.Caller(1 + skip)
record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
record.Add(args...)
logger.Handler().Handle(ctx, record)
}
} }

View File

@@ -430,13 +430,12 @@ type Tensor interface {
Sin(ctx Context) Tensor Sin(ctx Context) Tensor
Cos(ctx Context) Tensor Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor GELU(ctx Context) Tensor
SILU(ctx Context, up ...Tensor) Tensor QuickGELU(ctx Context) Tensor
RELU(ctx Context, up ...Tensor) Tensor SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor Sigmoid(ctx Context) Tensor
SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor

View File

@@ -1431,46 +1431,35 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
} }
} }
func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
if len(t2) > 0 { return &Tensor{
return &Tensor{ b: t.b,
b: t.b, t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
} }
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
b: t.b,
t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
}
}
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t), t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
} }
} }
func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
return &Tensor{ return &Tensor{
b: t.b, b: t.b,
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)), t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),

View File

@@ -26,7 +26,6 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
} }
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
if key != nil && value != nil { if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) { if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -40,7 +39,6 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
} }
ctx.Forward(key, value)
if cache != nil { if cache != nil {
cache.Put(ctx, key, value) cache.Put(ctx, key, value)
} }

View File

@@ -11,32 +11,26 @@ const (
TypeMean TypeMean
TypeCLS TypeCLS
TypeLast TypeLast
TypeRank
TypeUnknown = 0xFFFFFFFE
TypeUnspecified = 0xFFFFFFFF
) )
func (t Type) String() string { func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
switch t { switch poolingType {
case TypeMean: case TypeNone:
return "Mean" return hiddenStates
case TypeCLS:
return "CLS"
case TypeLast:
return "Last"
default:
return "Unknown"
}
}
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
switch t {
case TypeMean: case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS: case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast: case TypeLast:
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) panic("not implemented")
return hiddenStates case TypeRank:
panic("not implemented")
default: default:
panic("unknown pooling type") panic("not implemented")
} }
} }

View File

@@ -1,79 +0,0 @@
package pooling_test
import (
"bytes"
"os"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/discover"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn/pooling"
)
func setup(tb testing.TB, n int) ml.Backend {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
if err := fsggml.WriteGGUF(f, fsggml.KV{
"general.architecture": "test",
"test.block_count": uint32(1),
}, []*fsggml.Tensor{
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
}); err != nil {
tb.Fatal(err)
}
var gpuLayers ml.GPULayersList
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
gpuLayers = append(gpuLayers, ml.GPULayers{
ID: gpus[0].ID,
Layers: slices.Collect(func(yield func(int) bool) {
for i := range n {
if !yield(i) {
return
}
}
}),
})
}
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
if err != nil {
tb.Fatal(err)
}
return b
}
func TestForward(t *testing.T) {
cases := map[pooling.Type][]float32{
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
}
for typ, want := range cases {
t.Run(typ.String(), func(t *testing.T) {
b := setup(t, 99)
defer b.Close()
ctx := b.NewContext()
defer ctx.Close()
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
tt = typ.Forward(ctx, tt)
ctx.Forward(tt).Compute(tt)
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
t.Error(diff)
}
})
}
}

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"iter" "iter"
"log/slog" "log/slog"
"slices"
"strings" "strings"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
@@ -14,28 +13,16 @@ import (
) )
type BytePairEncoding struct { type BytePairEncoding struct {
vocab *Vocabulary pre *regexp2.Regexp
regexps []*regexp2.Regexp vocab *Vocabulary
} }
var _ TextProcessor = (*BytePairEncoding)(nil) var _ TextProcessor = (*BytePairEncoding)(nil)
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
if len(pretokenizers) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{ return BytePairEncoding{
pre: regexp2.MustCompile(pre, regexp2.None),
vocab: vocab, vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
for _, p := range pretokenizers {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
}
}),
} }
} }
@@ -48,36 +35,13 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool {
} }
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
parts := []string{s} return func(yield func(string) bool) {
for _, re := range bpe.regexps { for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
parts = slices.Collect(func(yield func(string) bool) { if !yield(m.String()) {
for _, part := range parts { break
r := []rune(part)
var offset int
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
if offset-m.Index != 0 {
if !yield(string(r[:m.Index])) {
return
}
}
if !yield(m.String()) {
return
}
offset = m.Index + m.Length
}
if offset < len(r) {
if !yield(string(r[offset:])) {
return
}
}
} }
}) }
} }
return slices.Values(parts)
} }
// fragment is a string fragment and their corresponding token IDs // fragment is a string fragment and their corresponding token IDs

View File

@@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding {
} }
return NewBytePairEncoding( return NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&Vocabulary{ &Vocabulary{
Values: tokens, Values: tokens,
Types: types, Types: types,
Merges: merges, Merges: merges,
}, },
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
) )
} }
@@ -282,41 +282,3 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}) })
} }
} }
func TestSplit(t *testing.T) {
cases := []struct {
name string
patterns,
want []string
}{
{
name: "default",
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
},
{
name: "unicode",
patterns: []string{
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
},
{
name: "individual digits",
patterns: []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
},
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
t.Errorf("no match (-theirs +ours):\n%s", diff)
}
})
}
}

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
_ "image/jpeg" _ "image/jpeg"
_ "image/png" _ "image/png"
"log/slog" "math"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
@@ -21,7 +21,6 @@ import (
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend" _ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
) )
@@ -108,12 +107,23 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return nil, err return nil, err
} }
m, err := modelForArch(b.Config()) arch := b.Config().Architecture()
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(b.Config())
if err != nil { if err != nil {
return nil, err return nil, err
} }
base := Base{b: b, config: m.Config()} base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m) v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem())) v.Elem().Set(populateFields(base, v.Elem()))
return m, nil return m, nil
@@ -125,38 +135,30 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
meta, err := fsggml.Decode(r, -1) meta, err := fsggml.Decode(r, -1)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return getTextProcessor(meta.KV())
}
m, err := modelForArch(meta.KV()) func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(kv)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tp, ok := m.(TextProcessor) tp, ok := m.(TextProcessor)
if !ok { if !ok {
return nil, ErrUnsupportedTokenizer return nil, fmt.Errorf("%v is not a TextProcessor", m)
} }
return tp, nil return tp, nil
} }
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, ErrUnsupportedModel
}
return f(c)
}
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type() t := v.Type()
@@ -172,44 +174,35 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
// make a copy // make a copy
tagsCopy := tags tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" { if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, parseTag(tag)) tagsCopy = append(tagsCopy, ParseTags(tag))
} }
if tt == reflect.TypeOf((*Base)(nil)).Elem() { if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(base)) vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag, string, string) [][]string var fn func([]Tag) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { fn = func(tags []Tag) (names [][]string) {
if len(tags) > 0 { if len(tags) > 0 {
var names []string localNames := []string{tags[0].Name}
if tags[0].name != "" { localNames = append(localNames, tags[0].Alternate...)
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
names = append(names, prefix+n+suffix) for _, localName := range localNames {
} fullName := []string{localName}
} nested := fn(tags[1:])
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix) if len(nested) > 0 {
if len(names) == 0 { for _, rest := range nested {
// current tag has no name, use child names only names = append(names, append(fullName, rest...))
fullNames = append(fullNames, childNames...)
} else if len(childNames) == 0 {
// current tag has names but no children, create branches for each name
for _, name := range names {
fullNames = append(fullNames, []string{name})
}
} else {
// merge each name with each child
for _, name := range names {
for _, childName := range childNames {
fullNames = append(fullNames, append([]string{name}, childName...))
} }
} else {
names = append(names, fullName)
} }
} }
} }
return fullNames return names
} }
names := fn(tagsCopy, "", "") names := fn(tagsCopy)
for _, name := range names { for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor) logutil.Trace("found tensor", "", tensor)
@@ -223,9 +216,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
for i := range vv.Len() { for i := range vv.Len() {
vvv := vv.Index(i) vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
} else { } else {
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
} }
} }
} }
@@ -264,31 +257,18 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
} }
type Tag struct { type Tag struct {
name, Name string
// prefix and suffix are applied to child tags Alternate []string
prefix,
suffix string
alternatives []string
} }
func parseTag(s string) (tag Tag) { func ParseTags(s string) (tag Tag) {
parts := strings.Split(s, ",") parts := strings.Split(s, ",")
if len(parts) > 0 { if len(parts) > 0 {
tag.name = parts[0] tag.Name = parts[0]
for _, part := range parts[1:] { for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { if value, ok := strings.CutPrefix(part, "alt:"); ok {
// elevate alternative to primary if no primary given tag.Alternate = append(tag.Alternate, value)
tag.name = value
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
} else if ok {
tag.alternatives = append(tag.alternatives, value)
}
if value, ok := strings.CutPrefix(part, "pre:"); ok {
tag.prefix = value
}
if value, ok := strings.CutPrefix(part, "suf:"); ok {
tag.suffix = value
} }
} }
} }

View File

@@ -1,9 +1,9 @@
package model package model
import ( import (
"errors"
"reflect" "reflect"
"slices" "slices"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
) )
func TestParseTags(t *testing.T) { func TestParseTags(t *testing.T) {
@@ -22,14 +23,14 @@ func TestParseTags(t *testing.T) {
{ {
value: "output", value: "output",
want: Tag{ want: Tag{
name: "output", Name: "output",
}, },
}, },
{ {
value: "output,alt:token_embd", value: "output,alt:token_embd",
want: Tag{ want: Tag{
name: "output", Name: "output",
alternatives: []string{ Alternate: []string{
"token_embd", "token_embd",
}, },
}, },
@@ -38,8 +39,8 @@ func TestParseTags(t *testing.T) {
for _, tt := range cases { for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) { t.Run(tt.value, func(t *testing.T) {
got := parseTag(tt.value) got := ParseTags(tt.value)
if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" { if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
} }
}) })
@@ -125,7 +126,6 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
Input *nn.Embedding `gguf:"input"` Input *nn.Embedding `gguf:"input"`
Output *nn.Linear `gguf:"output,alt:input"` Output *nn.Linear `gguf:"output,alt:input"`
Nested *nested `gguf:"nested"` Nested *nested `gguf:"nested"`
Tensor ml.Tensor `gguf:"leaf,alt:tensor"`
} }
var m fakeModel var m fakeModel
@@ -134,7 +134,6 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
names: []string{ names: []string{
"input.weight", "input.weight",
"nested.b.weight", "nested.b.weight",
"leaf",
}, },
}}, v.Elem())) }}, v.Elem()))
@@ -144,115 +143,44 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
Nested: &nested{ Nested: &nested{
Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}}, Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}},
}, },
Tensor: &fakeTensor{Name: "leaf"},
}, m); diff != "" { }, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
} }
} }
func TestPopulateFieldsPrefixSuffixName(t *testing.T) { func TestGetTextProcessor(t *testing.T) {
type fakeBlock struct { tp, err := getTextProcessor(fsggml.KV{})
A *nn.Linear `gguf:"a"` if err == nil {
B *nn.Linear `gguf:",pre:b_"` t.Error("expected error")
C *nn.Linear `gguf:",suf:_c"` } else if !strings.Contains(err.Error(), "unsupported model architecture") {
XY *nn.Linear `gguf:",pre:x_,suf:_y"` t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
} }
type fakeModel struct { models["dummy"] = func(fs.Config) (Model, error) {
Blocks []fakeBlock `gguf:"blk"` return notTextProcessorModel{}, nil
} }
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
m := fakeModel{ if err == nil {
Blocks: make([]fakeBlock, 2), t.Error("expected error")
} } else if !strings.Contains(err.Error(), "not a TextProcessor") {
v := reflect.ValueOf(&m) t.Errorf("unexpected error: %v", err)
v.Elem().Set(populateFields(Base{b: &fakeBackend{ } else if tp != nil {
names: []string{ t.Error("expected nil tp")
"blk.0.a.weight",
"blk.0.b_weight",
"blk.0.b_bias",
"blk.0.weight_c",
"blk.0.x_weight_y",
"blk.1.a.weight",
"blk.1.b_weight",
"blk.1.b_bias",
"blk.1.weight_c",
"blk.1.x_weight_y",
},
}}, v.Elem()))
if diff := cmp.Diff(fakeModel{
Blocks: []fakeBlock{
{
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
},
{
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
},
},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
} }
} }
func TestModelForArch(t *testing.T) { type notTextProcessorModel struct{}
type fakeModel struct {
Model
}
type fakeEmbeddingModel struct { func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
Model panic("unimplemented")
} }
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil } func (notTextProcessorModel) Backend() ml.Backend {
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil } panic("unimplemented")
}
cases := []struct {
name string func (notTextProcessorModel) Config() config {
config fs.Config panic("unimplemented")
want any
err error
}{
{
name: "model",
config: fsggml.KV{
"general.architecture": "model",
},
want: fakeModel{},
},
{
name: "embedding",
config: fsggml.KV{
"general.architecture": "model",
"model.pooling_type": uint32(1),
},
want: fakeEmbeddingModel{},
},
{
name: "unsupported",
config: fsggml.KV{
"general.architecture": "unsupported",
},
err: ErrUnsupportedModel,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := modelForArch(tt.config)
if !errors.Is(err, tt.err) {
t.Fatal(err)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
}
})
}
} }

View File

@@ -37,7 +37,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
} }
hiddenStates = m.poolingType.Forward(ctx, hiddenStates) hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
if m.normalize { if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
} }

View File

@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
attnValLen: int(c.Uint("attention.value_length")), attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0), ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.scaling.factor", 1.0), ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"), attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"), finalLogitSoftcap: c.Float("final_logit_softcapping"),
}, },
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling { if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
} }
type MLP struct { type MLP struct {
@@ -138,7 +138,7 @@ type MLP struct {
} }
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }

View File

@@ -22,7 +22,7 @@ type embedModel struct {
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenStates = m.poolingType.Forward(ctx, hiddenStates) hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
for _, dense := range m.Dense { for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates) hiddenStates = dense.Forward(ctx, hiddenStates)
} }

View File

@@ -53,10 +53,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: 1, ropeScale: c.Float("rope.freq_scale", 1.0),
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
}, },
} }
@@ -87,7 +84,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize) q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps) q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling { if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +95,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps) k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -116,7 +113,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.TextConfig.ropeGlobalBase ropeBase = m.TextConfig.ropeGlobalBase
} }
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
} }
type TextMLP struct { type TextMLP struct {
@@ -126,7 +123,7 @@ type TextMLP struct {
} }
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }

View File

@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.ropeBaseLocal ropeBase = m.ropeBaseLocal
} }
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
} }
type TextScaledWordEmbedding struct { type TextScaledWordEmbedding struct {
@@ -170,7 +170,8 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
} }
active = d.PerLayerInputGate.Forward(ctx, active) active = d.PerLayerInputGate.Forward(ctx, active)
active = active.GELU(ctx, perLayerInput) active = active.GELU(ctx)
active = active.Mul(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active) active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps) active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -256,14 +257,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
query := attn.Query.Forward(ctx, hiddenStates) query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps) query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor var key, value ml.Tensor
if !sharedKV { if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates) key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps) key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates) value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -291,7 +292,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx) hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
} }
hiddenStates = hiddenStates.GELU(ctx, upStates) hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates) hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates return hiddenStates
} }
@@ -349,7 +350,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: c.Float("rope.freq_base", 1_000_000), ropeBase: c.Float("rope.freq_base", 1_000_000),
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000), ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
ropeScale: c.Float("rope.scaling.factor", 1.0), ropeScale: c.Float("rope.freq_scale", 1.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"), slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
activationSparsityScale: c.Floats("activation_sparsity_scale"), activationSparsityScale: c.Floats("activation_sparsity_scale"),

View File

@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts) up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
} }
hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7) hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights) experts = experts.Mul(ctx, routingWeights)
@@ -227,6 +227,17 @@ func New(c fs.Config) (model.Model, error) {
m := Transformer{ m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -239,15 +250,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
strings.Join([]string{
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
`\p{N}{1,3}`,
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
`\s*[\r\n]+`,
`\s+(?!\S)`,
`\s+`,
}, "|"),
), ),
Options: Options{ Options: Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),

View File

@@ -2,6 +2,7 @@ package llama
import ( import (
"cmp" "cmp"
"fmt"
"math" "math"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
@@ -22,80 +23,51 @@ type Options struct {
type Model struct { type Model struct {
model.Base model.Base
model.TextProcessor model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"` TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"` Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"` OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"` Output *nn.Linear `gguf:"output,alt:token_embd"`
Options *Options
} }
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
if c.Uint("expert_count") > 0 { // This model currently only supports the gpt2 tokenizer
// TODO: support mixtures of experts if c.String("tokenizer.ggml.model") == "llama" {
return nil, model.ErrUnsupportedModel return nil, fmt.Errorf("unsupported tokenizer: llama")
} }
// Best effort detection of library/deepseek-coder model(s) which are incompatible
var processor model.TextProcessor if c.String("general.name") == "deepseek-ai" {
vocabulary := model.Vocabulary{ return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
} }
switch c.String("tokenizer.ggml.model") {
case "gpt2":
var pretokenizers []string
switch c.String("tokenizer.ggml.pre") {
case "default":
// no-op use the default bpe pretokenizer
case "qwen2":
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "refact":
pretokenizers = []string{
`\p{N}`,
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
}
case "tekken":
pretokenizers = []string{
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
default:
// use a llama-style pretokenizer
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
processor = model.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{ m := Model{
TextProcessor: processor, BytePairEncoding: model.NewBytePairEncoding(
Layers: make([]Layer, c.Uint("block_count")), c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
Options: Options{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")), numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")), numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")), headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 1e5), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
}, },
} }
@@ -126,8 +98,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -136,7 +108,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
} }
type MLP struct { type MLP struct {
@@ -146,7 +118,7 @@ type MLP struct {
} }
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@@ -191,7 +163,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
outputs = batch.Outputs outputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -34,6 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer",
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -46,7 +48,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope { if useRope {
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
} }
if opts.useQKNorm { if opts.useQKNorm {
@@ -58,14 +58,14 @@ type TextMLP struct {
} }
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
type TextExperts struct { type TextExperts struct {
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"` Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"` Down *nn.Linear `gguf:"ffn_down_exps"`
} }
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores) hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.Forward(ctx, hiddenStates, experts) upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.Forward(ctx, hiddenStates, experts) gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ { for i := 1; i < opts.numExpertsUsed; i++ {
@@ -88,10 +88,22 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
return nextStates return nextStates
} }
// TextSharedExpert is TextMLP with different tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct { type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"` Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts Experts *TextExperts
SharedExpert *TextMLP `gguf:",suf:_shexp"` SharedExpert *TextSharedExpert
} }
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -184,7 +196,7 @@ func newTextModel(c fs.Config) *TextModel {
numExpertsUsed: int(c.Uint("expert_used_count")), numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)), interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
noRopeInterval: int(c.Uint("no_rope_interval", 4)), noRopeInterval: int(c.Uint("no_rope_interval", 4)),
@@ -236,5 +248,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
} }

View File

@@ -33,6 +33,7 @@ var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := &Model{ m := &Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -45,7 +46,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
TextModel: newTextModel(c), TextModel: newTextModel(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale) k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
} }
type MLP struct { type MLP struct {
@@ -65,7 +65,7 @@ type MLP struct {
} }
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
}, },
} }
} }

View File

@@ -51,7 +51,7 @@ type VisionMLP struct {
} }
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }

View File

@@ -33,6 +33,7 @@ const (
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -45,7 +46,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
query := sa.Query.Forward(ctx, hiddenState) query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key := sa.Key.Forward(ctx, hiddenState) key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers // This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok { if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
} }
return key, nil return key, nil
@@ -58,7 +58,7 @@ type TextMLP struct {
} }
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor { func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"), crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
}, },
} }

View File

@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
value := attn.Value.Forward(ctx, hiddenStates) value := attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -59,7 +59,7 @@ type MLP struct {
} }
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
} }
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
@@ -139,6 +139,7 @@ func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")), Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -151,7 +152,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
Options: Options{ Options: Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
headDim: int(c.Uint("attention.key_length")), headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
}, },
} }

View File

@@ -29,6 +29,7 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
m := &Model{ m := &Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -41,7 +42,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
TextModel: NewTextModel(c), TextModel: NewTextModel(c),
VisionModel: newVisionModel(c), VisionModel: newVisionModel(c),

View File

@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
originalContextLength: int(c.Uint("context_length", 128000)), originalContextLength: int(c.Uint("context_length", 128000)),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
}, },
} }
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState) q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k := sa.Key.Forward(ctx, hiddenState) k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX()) k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState) v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching // Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
} }
// MLP implements the feed-forward network component with SwiGLU activation // MLP implements the feed-forward network component with SwiGLU activation
@@ -90,7 +90,7 @@ type MLP struct {
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating // Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension // Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState) return mlp.Down.Forward(ctx, hiddenState)
} }

View File

@@ -100,7 +100,8 @@ type VisionMLP struct {
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish) // Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates) gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }

View File

@@ -1,73 +0,0 @@
package qwen3
import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.BytePairEncoding
*Model
poolingType pooling.Type
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbed(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
for i := range layers {
layers[i].MLP = &dense{}
}
m := embedModel{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
),
Model: &Model{
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
},
},
poolingType: pooling.Type(c.Uint("pooling_type")),
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}

View File

@@ -30,10 +30,10 @@ func (o Options) headDim() int {
} }
type Attention struct { type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"` Query *nn.Linear `gguf:"attn_q"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"` Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"` Output *nn.Linear `gguf:"attn_output"`
} }
@@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = sa.QueryNorm.Forward(ctx, query, opts.eps) query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps) key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX()) key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
@@ -65,10 +65,10 @@ type MLP interface {
} }
type sparse struct { type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"` Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` Gate *nn.Linear `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"` Up *nn.Linear `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"` Down *nn.Linear `gguf:"ffn_down_exps"`
} }
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,9 +87,13 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts)) upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.SILU(ctx)
hiddenStates = hiddenStates.Mul(ctx, upStates)
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights) experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -107,8 +111,7 @@ type dense struct {
} }
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates). hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates) return mlp.Down.Forward(ctx, hiddenStates)
} }
@@ -151,25 +154,14 @@ type Model struct {
*Options *Options
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates, err := m.forward(ctx, batch)
if err != nil {
return nil, err
}
return m.Output.Forward(ctx, hiddenStates), nil
}
// Forward implements model.Model. // Forward implements model.Model.
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers { for i, layer := range m.Layers {
if m.Cache != nil { m.Cache.SetLayer(i)
m.Cache.SetLayer(i)
}
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
@@ -179,11 +171,12 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
} }
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
} }
var _ model.Model = (*Model)(nil) var _ model.Model = (*Model)(nil)
@@ -200,6 +193,7 @@ func New(c fs.Config) (model.Model, error) {
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"), Types: c.Ints("tokenizer.ggml.token_type"),
@@ -212,7 +206,6 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")..., c.Ints("tokenizer.ggml.eos_token_ids")...,
), ),
}, },
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
), ),
Layers: layers, Layers: layers,
Options: &Options{ Options: &Options{
@@ -223,7 +216,7 @@ func New(c fs.Config) (model.Model, error) {
valueLength: int(c.Uint("attention.value_length")), valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"), ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.freq_scale", 1),
numExperts: int(c.Uint("expert_count")), numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")), numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true), normTopKProb: c.Bool("norm_top_k_prob", true),
@@ -237,5 +230,4 @@ func New(c fs.Config) (model.Model, error) {
func init() { func init() {
model.Register("qwen3", New) model.Register("qwen3", New)
model.Register("qwen3moe", New) model.Register("qwen3moe", New)
model.Register("qwen3_embed", newEmbed)
} }

View File

@@ -1,49 +0,0 @@
package parsers
import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type Parser interface {
// Init initializes the parser with tools and optional last message for chat prefill
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
// Add processes streamed content and returns parsed content, thinking, and tool calls
// The done flag indicates if this is the last chunk (used for draining accumulators)
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
HasToolSupport() bool
HasThinkingSupport() bool
}
func ParserForName(name string) Parser {
switch name {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
case "passthrough":
return &PassthroughParser{}
case "harmony":
return harmony.NewHarmonyMessageHandler()
default:
return nil
}
}
type PassthroughParser struct{}
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
return tools // passthrough doesn't modify tools
}
func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
return s, "", nil, nil
}
func (p *PassthroughParser) HasToolSupport() bool {
return false
}
func (p *PassthroughParser) HasThinkingSupport() bool {
return false
}

View File

@@ -1,453 +0,0 @@
package parsers
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"math"
"regexp"
"strconv"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type qwenParserState int
const (
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
const (
qwenParserState_LookingForToolStart qwenParserState = iota
qwenParserState_CollectingToolContent
)
type Qwen3CoderParser struct {
state qwenParserState
acc strings.Builder
tools []api.Tool
}
func (p *Qwen3CoderParser) HasToolSupport() bool {
return true
}
func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false
}
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
return tools // Qwen doesn't modify tools
}
func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.acc.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var sb strings.Builder
for _, event := range events {
switch event := event.(type) {
case qwenEventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case qwenEventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here. See the note below about
// `qwenEvent`s for more details
sb.WriteString(event.content)
}
}
return sb.String(), "", toolCalls, nil
}
func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
var all []qwenEvent
keepLooping := true
for keepLooping {
var events []qwenEvent
events, keepLooping = eat(p)
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
}
return all
}
// we use some internal event types in order to communicate between `Add` and
// `eat`. We do this to support interleaving content and parallel tool calls in
// the parser, even though qwen3-coder isn't supposed to do this. Our API
// doesn't currently support models outputting multiple messages in a turn, so
// we wouldn't be able to represent it yet, but there's no reason to prevent the
// parser from supporting it, especially for future models if they end up using
// a similar format.
type qwenEvent interface {
isQwenEvent()
}
type qwenEventRawToolCall struct {
raw string
}
type qwenEventContent struct {
content string
}
func (qwenEventContent) isQwenEvent() {}
func (qwenEventRawToolCall) isQwenEvent() {}
// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. If the parser transitions to another
// state, it may have additional events to emit on the next call, which is what
// the second return value indicates
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
var events []qwenEvent
switch p.state {
case qwenParserState_LookingForToolStart:
if strings.Contains(p.acc.String(), toolOpenTag) {
// we found a full tool open tag, so we can emit the content before the
// tag, being sure to trim any trailing whitespace
split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
if len(before) > 0 {
events = append(events, qwenEventContent{content: before})
}
after := split[1]
p.acc.Reset()
p.acc.WriteString(after)
p.state = qwenParserState_CollectingToolContent
return events, true
} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
// we found a partial tool open tag, so we can emit the unambiguous part,
// which is the (trailing-whitespace trimmed) content before the partial
// tool open tag
beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.acc.String()[:ambiguousStart]
ambiguous := p.acc.String()[ambiguousStart:]
p.acc.Reset()
p.acc.WriteString(ambiguous)
events = append(events, qwenEventContent{content: unambiguous})
return events, false
} else {
// we found content that is entirely not a tool call. We should withhold
// any trailing whitespace in case this is the end of the content
whitespaceLen := trailingWhitespaceLen(p.acc.String())
ambiguousStart := len(p.acc.String()) - whitespaceLen
unambiguous := p.acc.String()[:ambiguousStart]
ambiguous := p.acc.String()[ambiguousStart:]
p.acc.Reset()
p.acc.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, qwenEventContent{content: unambiguous})
}
return events, false
}
case qwenParserState_CollectingToolContent:
if strings.Contains(p.acc.String(), toolCloseTag) {
split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
// remove any whitespace between the tool call and any content after it
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.acc.Reset()
p.acc.WriteString(after)
events = append(events, qwenEventRawToolCall{raw: before})
p.state = qwenParserState_LookingForToolStart
return events, true
} else {
// note that we don't need to check the overlap here because we only plan
// on parsing the tool call once we see the full closing tag. We don't
// stream back the unparsed tool content, so there's no need to be eager
// here
return events, false
}
default:
panic("unreachable")
}
}
// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
max := min(len(delim), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
return i
}
}
return 0
}
func trailingWhitespaceLen(s string) int {
for i := len(s) - 1; i >= 0; i-- {
if !unicode.IsSpace(rune(s[i])) {
return len(s) - i - 1
}
}
return len(s)
}
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`
Parameters []XMLParameter `xml:"parameter"`
}
type XMLParameter struct {
Name string `xml:"name,attr"`
Value string `xml:",chardata"`
}
// parseToolCall parses a raw tool call string into an api.ToolCall.
// The raw string follows an xml-like format, here's an example:
//
// <function=get_current_temperature>
// <parameter=location>
// San Francisco
// </parameter>
// <parameter=unit>
// celsius
// </parameter>
// </function>
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
toolCall := api.ToolCall{}
xmlString := transformToXML(raw.raw)
var functionCall XMLFunctionCall
err := xml.Unmarshal([]byte(xmlString), &functionCall)
if err != nil {
return api.ToolCall{}, err
}
toolCall.Function = api.ToolCallFunction{
Name: functionCall.Name,
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionCall.Name {
matchedTool = &tools[i]
break
}
}
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
for _, parameter := range functionCall.Parameters {
// Look up the parameter type if we found the tool
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
paramType = prop.Type
}
}
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
}
return toolCall, nil
}
// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
//
// For union types (multiple types in PropertyType, which we support but doesn't
// seem as though the reference parser does type coercion with those types in
// mind) we use a type precedence approach:
// 1. null - checked first regardless of declared types (matches reference implementation)
// 2. boolean - only "true"/"false" are valid booleans
// 3. integer - must parse as a whole number
// 4. number - must parse as numeric (returns int if no decimal part)
// 5. array - must parse as valid JSON array
// 6. object - must parse as valid JSON object
// 7. string - always succeeds (least specific type)
//
// This precedence ensures we return the most specific type that successfully parses,
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
func parseValue(raw string, paramType api.PropertyType) any {
// first remove a single leading newlines, and a single trailing newline (if
// they exist). This follows the reference implementation
raw = strings.TrimPrefix(raw, "\n")
raw = strings.TrimSuffix(raw, "\n")
// Check for null first (case-insensitive) - this takes precedence over any type
if strings.ToLower(raw) == "null" {
return nil
}
// If no type is specified, default to string
if len(paramType) == 0 {
return raw
}
// Check if any of the specified types match, using type precedence
// Order: boolean -> integer -> number -> array -> object -> string
typeSet := make(map[string]bool)
for _, t := range paramType {
typeSet[t] = true
}
// Try boolean first (most restrictive)
if typeSet["boolean"] {
lower := strings.ToLower(raw)
switch lower {
case "true":
return true
case "false":
return false
}
// If not a valid boolean but boolean is the only type, return false (matching reference)
if len(paramType) == 1 {
return false
}
// Otherwise try other types
}
// Try integer
if typeSet["integer"] {
if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
// Return as int if it fits in int32, otherwise int64
if i >= math.MinInt32 && i <= math.MaxInt32 {
return int(i)
}
return i
}
// If integer is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try number (float)
if typeSet["number"] {
if f, err := strconv.ParseFloat(raw, 64); err == nil {
// If the number has no decimal part, return as int (matching reference)
if f == math.Trunc(f) {
i := int64(f)
if i >= math.MinInt32 && i <= math.MaxInt32 {
return int(i)
}
return i
}
return f
}
// If number is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try array
if typeSet["array"] {
var arr []any
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
return arr
}
// If array is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// Try object
if typeSet["object"] {
var obj map[string]any
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
return obj
}
// If object is the only type and parsing failed, fall back to string
if len(paramType) == 1 {
return raw
}
}
// String always succeeds (or if "string" is in the type set)
if typeSet["string"] {
return raw
}
// If we get here, none of the types matched and string wasn't an option
// We return string as a fallback. The reference implementation will attempt
// to parse the value as a python literal, but we purposefully don't support
// that
return raw
}
var (
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
)
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
// xml so that it can be parsed by any xml parser
func transformToXML(raw string) string {
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
// care to properly escape the string that becomes the attribute value
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1]
var escapedValue strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2]))
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
})
// Walk the resulting string, escaping any character data that sits between the
// xml tags we just emitted
var out strings.Builder
lastIdx := 0
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
if loc[0] > lastIdx {
escapeTextNode(&out, transformed[lastIdx:loc[0]])
}
out.WriteString(transformed[loc[0]:loc[1]])
lastIdx = loc[1]
}
if lastIdx < len(transformed) {
escapeTextNode(&out, transformed[lastIdx:])
}
return out.String()
}
// escapeTextNode escapes XML character data without altering other characters
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
func escapeTextNode(sb *strings.Builder, s string) {
for _, r := range s {
switch r {
case '&':
sb.WriteString("&amp;")
case '<':
sb.WriteString("&lt;")
case '>':
sb.WriteString("&gt;")
default:
sb.WriteRune(r)
}
}
}

View File

@@ -1,878 +0,0 @@
package parsers
import (
"reflect"
"testing"
"github.com/ollama/ollama/api"
)
// tool creates a test tool with the given name and properties
func tool(name string, props map[string]api.ToolProperty) api.Tool {
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
t.Function.Parameters.Type = "object"
t.Function.Parameters.Properties = props
return t
}
func TestQwenParserStreaming(t *testing.T) {
type step struct {
input string
wantEvents []qwenEvent
}
cases := []struct {
desc string
steps []step
only bool
}{
{
desc: "simple message streamed word by word",
steps: []step{
{
input: "hi",
wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
},
{
input: " there",
wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
},
},
},
{
desc: "content before tool call",
steps: []step{
{
input: "hi there<tool_call>",
wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
},
},
},
{
desc: "multiple tool calls in one message",
steps: []step{
{
input: "before1<tool_call>in tool call</tool_call>after1<tool_call>in tool call 2</tool_call>after2",
wantEvents: []qwenEvent{
qwenEventContent{content: "before1"},
qwenEventRawToolCall{raw: "in tool call"},
qwenEventContent{content: "after1"},
qwenEventRawToolCall{raw: "in tool call 2"},
qwenEventContent{content: "after2"},
},
},
},
},
{
desc: "tool calls with split tags",
steps: []step{
{
input: "before<tool",
wantEvents: []qwenEvent{
qwenEventContent{content: "before"},
},
},
{
input: "_call>in tool call</tool",
wantEvents: []qwenEvent{},
},
{
input: "_call>af",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "in tool call"},
qwenEventContent{content: "af"},
},
},
{
input: "ter",
wantEvents: []qwenEvent{
qwenEventContent{content: "ter"},
},
},
},
},
{
desc: "trailing whitespace between content and tool call",
steps: []step{
{
input: "abc\n<tool_call>def</tool_call>",
wantEvents: []qwenEvent{
qwenEventContent{content: "abc"},
qwenEventRawToolCall{raw: "def"},
},
},
},
},
{
desc: "trailing whitespace between tool call and content",
steps: []step{
{
input: "<tool_call>abc</tool_call>\ndef",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "abc"},
qwenEventContent{content: "def"},
},
},
},
},
{
desc: "empty content before tool call",
steps: []step{
{
input: "\n<tool_call>abc</tool_call>",
wantEvents: []qwenEvent{
qwenEventRawToolCall{raw: "abc"},
},
},
},
},
{
desc: "partial tool open tag fakeout",
steps: []step{
{
input: "abc\n<tool_call",
wantEvents: []qwenEvent{
// \n should not be emitted yet because `<tool_call` might be a tool
// open tag, in which case the whitespace should be trimmed
qwenEventContent{content: "abc"},
},
},
{
input: " fakeout",
wantEvents: []qwenEvent{
qwenEventContent{content: "\n<tool_call fakeout"},
},
},
},
},
{
desc: "token-by-token whitespace handling",
steps: []step{
{
input: "a",
wantEvents: []qwenEvent{
qwenEventContent{content: "a"},
},
},
{
input: "\n",
wantEvents: []qwenEvent{},
},
{
input: "b",
wantEvents: []qwenEvent{
qwenEventContent{content: "\nb"},
},
},
},
},
}
anyOnlies := false
for _, tc := range cases {
if tc.only {
anyOnlies = true
}
}
for _, tc := range cases {
if anyOnlies && !tc.only {
continue
}
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3CoderParser{}
for i, step := range tc.steps {
parser.acc.WriteString(step.input)
gotEvents := parser.parseEvents()
if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
// avoid deep equal on empty vs. nil slices
continue
}
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
}
}
})
}
}
func TestQwenToolParser(t *testing.T) {
type step struct {
name string
rawToolCall string
tools []api.Tool
wantToolCall api.ToolCall
}
steps := []step{
{
name: "simple tool call",
tools: []api.Tool{},
rawToolCall: `<function=get_current_temperature>
<parameter=location>
San Francisco
</parameter>
<parameter=unit>
celsius
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_temperature",
Arguments: map[string]any{
"location": "San Francisco",
"unit": "celsius",
},
},
},
},
{
name: "names with spaces",
tools: []api.Tool{},
rawToolCall: `<function=get current temperature>
<parameter=location with spaces>
San Francisco
</parameter>
<parameter=unit with spaces>
celsius
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get current temperature",
Arguments: map[string]any{
"location with spaces": "San Francisco",
"unit with spaces": "celsius",
},
},
},
},
// this mirrors the reference implementation's behavior, but unclear if it
// ever happens. If so, then we should probably remove them instead, this
// test is to just document the current behavior and test that we don't get
// xml errors
{
name: "names with quotes",
tools: []api.Tool{},
rawToolCall: `<function="get current temperature">
<parameter="location with spaces">
San Francisco
</parameter>
<parameter="unit with spaces">
"celsius"
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "\"get current temperature\"",
Arguments: map[string]any{
"\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"",
},
},
},
},
{
name: "tool call with typed parameters",
tools: []api.Tool{
tool("calculate", map[string]api.ToolProperty{
"x": {Type: api.PropertyType{"number"}},
"y": {Type: api.PropertyType{"integer"}},
"enabled": {Type: api.PropertyType{"boolean"}},
"items": {Type: api.PropertyType{"array"}},
}),
},
rawToolCall: `<function=calculate>
<parameter=x>
3.14
</parameter>
<parameter=y>
42
</parameter>
<parameter=enabled>
true
</parameter>
<parameter=items>
["a", "b", "c"]
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: map[string]any{
"x": 3.14,
"y": 42,
"enabled": true,
"items": []any{"a", "b", "c"},
},
},
},
},
// regression test for <https://github.com/ollama/ollama/issues/12357>
{
name: "ampersands in parameter values",
tools: []api.Tool{},
rawToolCall: `<function=exec>
<parameter=command>
ls && echo "done"
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: map[string]any{
"command": "ls && echo \"done\"",
},
},
},
},
{
name: "angle brackets in parameter values",
tools: []api.Tool{},
rawToolCall: `<function=exec>
<parameter=command>
ls && echo "a > b and a < b"
</parameter>
</function>`,
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"",
},
},
},
},
}
for i, step := range steps {
gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err)
}
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
}
}
}
func TestQwenToolCallValueParsing(t *testing.T) {
cases := []struct {
desc string
raw string
paramType api.PropertyType
want any
}{
{
desc: "default string value (no type specified)",
paramType: api.PropertyType{},
raw: "some-string",
want: "some-string",
},
{
desc: "trim a single leading and trailing newline",
paramType: api.PropertyType{},
raw: "\nsome-string\n",
want: "some-string",
},
{
desc: "trim at most one leading and trailing newline",
paramType: api.PropertyType{},
raw: "\n\nsome-string\n\n",
want: "\nsome-string\n",
},
{
desc: "newline really has to be the first character to be trimmed",
paramType: api.PropertyType{},
raw: " \nsome-string\n ",
want: " \nsome-string\n ",
},
{
desc: "numeric type",
paramType: api.PropertyType{"number"},
raw: "123",
want: 123,
},
// Integer parsing tests
{
desc: "integer type",
paramType: api.PropertyType{"integer"},
raw: "42",
want: 42,
},
{
desc: "negative integer",
paramType: api.PropertyType{"integer"},
raw: "-100",
want: -100,
},
{
desc: "zero integer",
paramType: api.PropertyType{"integer"},
raw: "0",
want: 0,
},
{
desc: "integer with leading zeros",
paramType: api.PropertyType{"integer"},
raw: "007",
want: 7,
},
{
desc: "large integer",
paramType: api.PropertyType{"integer"},
raw: "2147483648", // Just beyond int32 max
want: int64(2147483648),
},
// Float/number parsing tests
{
desc: "float type",
paramType: api.PropertyType{"number"},
raw: "3.14",
want: 3.14,
},
{
desc: "negative float",
paramType: api.PropertyType{"number"},
raw: "-273.15",
want: -273.15,
},
{
desc: "float without decimal part",
paramType: api.PropertyType{"number"},
raw: "100.0",
want: 100,
},
{
desc: "scientific notation positive",
paramType: api.PropertyType{"number"},
raw: "1.23e5",
want: 123000, // Will be int since it has no decimal part
},
{
desc: "scientific notation negative",
paramType: api.PropertyType{"number"},
raw: "1.5e-3",
want: 0.0015,
},
{
desc: "very small float",
paramType: api.PropertyType{"number"},
raw: "0.00000001",
want: 0.00000001,
},
// String parsing tests
{
desc: "explicit string type",
paramType: api.PropertyType{"string"},
raw: "hello world",
want: "hello world",
},
{
desc: "string with special characters",
paramType: api.PropertyType{"string"},
raw: "/usr/local/bin/test-file_v2.0.sh",
want: "/usr/local/bin/test-file_v2.0.sh",
},
{
desc: "string with quotes",
paramType: api.PropertyType{"string"},
raw: `He said "hello" to me`,
want: `He said "hello" to me`,
},
{
desc: "multiline string",
paramType: api.PropertyType{"string"},
raw: "line one\nline two\nline three",
want: "line one\nline two\nline three",
},
{
desc: "empty string",
paramType: api.PropertyType{"string"},
raw: "",
want: "",
},
{
desc: "string that looks like a number",
paramType: api.PropertyType{"string"},
raw: "12345",
want: "12345",
},
// Boolean parsing tests
{
desc: "boolean true",
paramType: api.PropertyType{"boolean"},
raw: "true",
want: true,
},
{
desc: "boolean false",
paramType: api.PropertyType{"boolean"},
raw: "false",
want: false,
},
{
desc: "boolean case insensitive true",
paramType: api.PropertyType{"boolean"},
raw: "True",
want: true,
},
{
desc: "boolean case insensitive false",
paramType: api.PropertyType{"boolean"},
raw: "FALSE",
want: false,
},
// Null parsing tests
{
desc: "null value lowercase",
paramType: api.PropertyType{"string"},
raw: "null",
want: nil,
},
{
desc: "null value case insensitive",
paramType: api.PropertyType{"integer"},
raw: "NULL",
want: nil,
},
// Array parsing tests
{
desc: "array of strings",
paramType: api.PropertyType{"array"},
raw: `["foo", "bar", "baz"]`,
want: []any{"foo", "bar", "baz"},
},
{
desc: "array of numbers",
paramType: api.PropertyType{"array"},
raw: `[1, 2.5, 3]`,
want: []any{float64(1), 2.5, float64(3)},
},
{
desc: "array of mixed types",
paramType: api.PropertyType{"array"},
raw: `["string", 123, true, null]`,
want: []any{"string", float64(123), true, nil},
},
{
desc: "empty array",
paramType: api.PropertyType{"array"},
raw: `[]`,
want: []any{},
},
// Object parsing tests
{
desc: "simple object",
paramType: api.PropertyType{"object"},
raw: `{"key": "value", "number": 42}`,
want: map[string]any{"key": "value", "number": float64(42)},
},
{
desc: "nested object",
paramType: api.PropertyType{"object"},
raw: `{"outer": {"inner": "value"}}`,
want: map[string]any{"outer": map[string]any{"inner": "value"}},
},
{
desc: "empty object",
paramType: api.PropertyType{"object"},
raw: `{}`,
want: map[string]any{},
},
// Error cases and fallback behavior
{
desc: "invalid integer falls back to string",
paramType: api.PropertyType{"integer"},
raw: "not-a-number",
want: "not-a-number",
},
{
desc: "invalid float falls back to string",
paramType: api.PropertyType{"number"},
raw: "3.14.159",
want: "3.14.159",
},
{
desc: "invalid boolean falls back to false",
paramType: api.PropertyType{"boolean"},
raw: "yes",
want: false,
},
{
desc: "invalid JSON array falls back to string",
paramType: api.PropertyType{"array"},
raw: "[1, 2, unclosed",
want: "[1, 2, unclosed",
},
{
desc: "invalid JSON object falls back to string",
paramType: api.PropertyType{"object"},
raw: `{"key": unclosed`,
want: `{"key": unclosed`,
},
// Edge cases
{
desc: "integer overflow should use int64",
paramType: api.PropertyType{"integer"},
raw: "2147483648", // Beyond int32 max
want: int64(2147483648),
},
{
desc: "float with many decimal places",
paramType: api.PropertyType{"number"},
raw: "3.141592653589793",
want: 3.141592653589793,
},
{
desc: "string with JSON-like content",
paramType: api.PropertyType{"string"},
raw: `{"this": "is", "just": "a string"}`,
want: `{"this": "is", "just": "a string"}`,
},
{
desc: "whitespace-only string",
paramType: api.PropertyType{"string"},
raw: " ",
want: " ",
},
// Unknown parameter (no type specified in tools)
{
desc: "parameter not in tool definition defaults to string",
paramType: api.PropertyType{},
raw: "some value",
want: "some value",
},
// Union type tests
{
desc: "string or number union - valid number",
paramType: api.PropertyType{"string", "number"},
raw: "42.5",
want: 42.5,
},
{
desc: "string or number union - non-numeric string",
paramType: api.PropertyType{"string", "number"},
raw: "hello",
want: "hello",
},
{
desc: "number or string union - valid number (order shouldn't matter)",
paramType: api.PropertyType{"number", "string"},
raw: "42.5",
want: 42.5,
},
{
desc: "integer or null union - valid integer",
paramType: api.PropertyType{"integer", "null"},
raw: "123",
want: 123,
},
{
desc: "integer or null union - null value",
paramType: api.PropertyType{"integer", "null"},
raw: "null",
want: nil,
},
{
desc: "null or integer union - null value (order shouldn't matter)",
paramType: api.PropertyType{"null", "integer"},
raw: "null",
want: nil,
},
{
desc: "boolean or string union - valid boolean",
paramType: api.PropertyType{"boolean", "string"},
raw: "true",
want: true,
},
{
desc: "boolean or string union - non-boolean becomes string",
paramType: api.PropertyType{"boolean", "string"},
raw: "yes",
want: "yes",
},
{
desc: "string or boolean union - valid boolean (precedence test)",
paramType: api.PropertyType{"string", "boolean"},
raw: "false",
want: false, // Should be boolean, not string "false"
},
{
desc: "integer or number union - integer value",
paramType: api.PropertyType{"integer", "number"},
raw: "42",
want: 42,
},
{
desc: "integer or number union - float value",
paramType: api.PropertyType{"integer", "number"},
raw: "42.5",
want: 42.5,
},
{
desc: "number or integer union - integer value (precedence test)",
paramType: api.PropertyType{"number", "integer"},
raw: "42",
want: 42, // Should try integer first due to precedence
},
{
desc: "array or object union - valid array",
paramType: api.PropertyType{"array", "object"},
raw: `[1, 2, 3]`,
want: []any{float64(1), float64(2), float64(3)},
},
{
desc: "array or object union - valid object",
paramType: api.PropertyType{"array", "object"},
raw: `{"key": "value"}`,
want: map[string]any{"key": "value"},
},
{
desc: "object or array union - valid array (precedence test)",
paramType: api.PropertyType{"object", "array"},
raw: `[1, 2, 3]`,
want: []any{float64(1), float64(2), float64(3)},
},
{
desc: "complex multi-type union - null",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "null",
want: nil,
},
{
desc: "complex multi-type union - boolean",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "true",
want: true,
},
{
desc: "complex multi-type union - number",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "3.14",
want: 3.14,
},
{
desc: "complex multi-type union - string",
paramType: api.PropertyType{"string", "number", "boolean", "null"},
raw: "hello",
want: "hello",
},
{
desc: "integer string union - integer string becomes integer",
paramType: api.PropertyType{"integer", "string"},
raw: "123",
want: 123,
},
{
desc: "string integer union - integer string becomes integer (precedence)",
paramType: api.PropertyType{"string", "integer"},
raw: "123",
want: 123, // Integer has higher precedence than string
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got := parseValue(tc.raw, tc.paramType)
if !reflect.DeepEqual(got, tc.want) {
t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
}
})
}
}
func TestQwenXMLTransform(t *testing.T) {
cases := []struct {
desc string
raw string
want string
}{
{
desc: "simple example",
raw: `<function=get_current_temperature>
<parameter=location>
San Francisco
</parameter>
<parameter=unit>
celsius
</parameter>
</function>`,
want: `<function name="get_current_temperature">
<parameter name="location">
San Francisco
</parameter>
<parameter name="unit">
celsius
</parameter>
</function>`,
},
// even though quotes aren't expected in these tags, we have these tests to
// make sure they're escaped so they don't blow up the xml parser in case
// they happen
{
desc: "names with quotes",
raw: `<function="get current temperature">
<parameter="location with spaces">
San Francisco
</parameter>
<parameter="unit with spaces">
celsius
</parameter>
</function>`,
want: `<function name="&#34;get current temperature&#34;">
<parameter name="&#34;location with spaces&#34;">
San Francisco
</parameter>
<parameter name="&#34;unit with spaces&#34;">
celsius
</parameter>
</function>`,
},
{
desc: "ampersands in parameter values",
raw: `<function=get_current_temperature>
<parameter=location>
San Francisco & San Jose
</parameter>
</function>`,
want: `<function name="get_current_temperature">
<parameter name="location">
San Francisco &amp; San Jose
</parameter>
</function>`,
},
}
for _, tc := range cases {
got := transformToXML(tc.raw)
if got != tc.want {
t.Errorf("got %q, want %q", got, tc.want)
}
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
cases := []struct {
desc string
s string
want int
}{
{desc: "no whitespace", s: "abc", want: 0},
{desc: "trailing whitespace", s: "abc ", want: 1},
{desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
{desc: "only whitespace", s: " \n ", want: 4},
{desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
}
for _, tc := range cases {
got := trailingWhitespaceLen(tc.s)
if got != tc.want {
t.Errorf("got %d, want %d", got, tc.want)
}
}
}

View File

@@ -1,217 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/ollama/ollama/api"
)
var (
imStartTag = "<|im_start|>"
imEndTag = "<|im_end|>"
)
// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
// This follows the same approach from the reference implementation, which gives
// a particular key ordering
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
data, err := json.Marshal(obj)
if err != nil {
return ""
}
var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return ""
}
var sb strings.Builder
for key, value := range m {
if handledKeys[key] {
continue
}
// Check if value is a map or array (needs JSON serialization)
switch v := value.(type) {
case map[string]any, []any:
jsonBytes, _ := json.Marshal(v)
// TODO(drifkin): it would be nice to format the JSON here similarly to
// python's default json.dumps behavior (spaces after commas and colons).
// This would let us be byte-for-byte compatible with the reference
// implementation for most common inputs
jsonStr := string(jsonBytes)
sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
case nil:
continue
default:
// Simple types, convert to string
sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
}
}
return sb.String()
}
func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
// filter out system messages and choose the first (if any) to win
var systemMessage string
var filteredMessages []api.Message
for _, message := range messages {
if message.Role != "system" {
filteredMessages = append(filteredMessages, message)
continue
}
if systemMessage == "" {
systemMessage = message.Content
}
}
if systemMessage != "" || len(tools) > 0 {
sb.WriteString(imStartTag + "system\n")
// if we have tools but no system message, match the reference implementation by providing a default system message
if systemMessage == "" {
systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
}
sb.WriteString(systemMessage)
if len(tools) > 0 {
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
sb.WriteString("<tools>")
for _, tool := range tools {
sb.WriteString("\n")
sb.WriteString("<function>\n")
sb.WriteString("<name>" + tool.Function.Name + "</name>")
if tool.Function.Description != "" {
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
}
sb.WriteString("\n<parameters>")
for name, prop := range tool.Function.Parameters.Properties {
sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + name + "</name>")
if len(prop.Type) > 0 {
// TODO(!!!)(drifkin): we should match the reference implementation for
// more complex types here instead of using this format
sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
}
if prop.Description != "" {
sb.WriteString("\n<description>" + prop.Description + "</description>")
}
// Render any additional keys not already handled
handledKeys := map[string]bool{
"type": true,
"description": true,
}
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
sb.WriteString("\n</parameter>")
}
// Render extra keys for parameters (everything except 'type' and 'properties')
paramHandledKeys := map[string]bool{
"type": true,
"properties": true,
}
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
sb.WriteString("\n</parameters>")
sb.WriteString("\n</function>")
}
sb.WriteString("\n</tools>")
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
}
sb.WriteString(imEndTag + "\n")
}
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
prefill := lastMessage && message.Role == "assistant"
switch message.Role {
case "assistant":
if len(message.ToolCalls) > 0 {
sb.WriteString(imStartTag + "assistant\n")
if message.Content != "" {
sb.WriteString(message.Content + "\n")
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
for name, value := range toolCall.Function.Arguments {
valueStr := formatToolCallArgument(value)
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
}
sb.WriteString("\n</function>\n</tool_call>")
}
sb.WriteString("<|im_end|>\n")
} else {
sb.WriteString(imStartTag + "assistant\n")
sb.WriteString(message.Content)
if !prefill {
sb.WriteString(imEndTag + "\n")
}
}
case "tool":
// consecutive tool responses should share a single `<im_start>user`, but
// have their own <tool_response> tags
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
sb.WriteString(imStartTag + "user\n")
}
sb.WriteString("<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>\n")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
sb.WriteString(imEndTag + "\n")
}
default:
sb.WriteString(imStartTag + message.Role + "\n")
sb.WriteString(message.Content)
sb.WriteString(imEndTag + "\n")
}
if lastMessage && !prefill {
sb.WriteString(imStartTag + "assistant\n")
}
}
return sb.String(), nil
}
func formatToolCallArgument(value any) string {
if value == nil {
return "null"
}
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
}
if reflect.TypeOf(value) != nil {
kind := reflect.TypeOf(value).Kind()
if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
if marshalled, err := json.Marshal(value); err == nil {
return string(marshalled)
}
}
}
return fmt.Sprintf("%v", value)
}

View File

@@ -1,338 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestQwen3CoderRenderer(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "basic",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
`,
},
{
name: "with tools and response",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in San Francisco?"},
{
Role: "assistant",
Content: "I'll check the weather in San Francisco for you.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"unit": "fahrenheit",
},
},
},
},
},
{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
{Role: "user", Content: "That sounds nice! What about New York?"},
},
tools: []api.Tool{
{Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Required: []string{"unit"},
Properties: map[string]api.ToolProperty{
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
// TODO(drifkin): add multiple params back once we have predictable
// order via some sort of ordered map type (see
// <https://github.com/ollama/ollama/issues/12244>)
/*
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
*/
},
},
}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
# Tools
You have access to the following functions:
<tools>
<function>
<name>get_weather</name>
<description>Get the current weather in a given location</description>
<parameters>
<parameter>
<name>unit</name>
<type>string</type>
<description>The unit of temperature</description>
<enum>["celsius","fahrenheit"]</enum>
</parameter>
<required>["unit"]</required>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
What is the weather like in San Francisco?<|im_end|>
<|im_start|>assistant
I'll check the weather in San Francisco for you.
<tool_call>
<function=get_weather>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response>
<|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
`,
},
{
name: "parallel tool calls",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "call double(1) and triple(2)"},
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
}},
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
},
tools: []api.Tool{
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
}}}},
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
}}}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
# Tools
You have access to the following functions:
<tools>
<function>
<name>double</name>
<description>Double a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to double</description>
</parameter>
</parameters>
</function>
<function>
<name>triple</name>
<description>Triple a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to triple</description>
</parameter>
</parameters>
</function>
</tools>
If you choose to call a function ONLY reply in the following format with NO suffix:
<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>
<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
call double(1) and triple(2)<|im_end|>
<|im_start|>assistant
I'll call double(1) and triple(2) for you.
<tool_call>
<function=double>
<parameter=number>
1
</parameter>
</function>
</tool_call>
<tool_call>
<function=triple>
<parameter=number>
2
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"number": 2}
</tool_response>
<tool_response>
{"number": 6}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
{
name: "prefill",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Tell me something interesting."},
{Role: "assistant", Content: "I'll tell you something interesting about cats"},
},
expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Tell me something interesting.<|im_end|>
<|im_start|>assistant
I'll tell you something interesting about cats`,
},
{
name: "complex tool call arguments should remain json encoded",
msgs: []api.Message{
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: map[string]any{
"payload": map[string]any{"foo": "bar"},
},
}},
}},
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
},
expected: `<|im_start|>user
call tool<|im_end|>
<|im_start|>assistant
<tool_call>
<function=echo>
<parameter=payload>
{"foo":"bar"}
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"payload": {"foo": "bar"}}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestFormatToolCallArgument(t *testing.T) {
tests := []struct {
name string
arg any
expected string
}{
{
name: "string",
arg: "foo",
// notice no quotes around the string
expected: "foo",
},
{
name: "map",
arg: map[string]any{"foo": "bar"},
expected: "{\"foo\":\"bar\"}",
},
{
name: "number",
arg: 1,
expected: "1",
},
{
name: "boolean",
arg: true,
expected: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := formatToolCallArgument(tt.arg)
if got != tt.expected {
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
}
})
}
}

View File

@@ -1,26 +0,0 @@
package renderers
import (
"fmt"
"github.com/ollama/ollama/api"
)
type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
renderer := rendererForName(name)
if renderer == nil {
return "", fmt.Errorf("unknown renderer %q", name)
}
return renderer(msgs, tools, think)
}
func rendererForName(name string) rendererFunc {
switch name {
case "qwen3-coder":
return Qwen3CoderRenderer
default:
return nil
}
}

View File

@@ -105,18 +105,16 @@ type ChatCompletionRequest struct {
Tools []api.Tool `json:"tools"` Tools []api.Tool `json:"tools"`
Reasoning *Reasoning `json:"reasoning,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"`
DebugRenderOnly bool `json:"_debug_render_only"`
} }
type ChatCompletion struct { type ChatCompletion struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
Created int64 `json:"created"` Created int64 `json:"created"`
Model string `json:"model"` Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"` SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"` Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"` Usage Usage `json:"usage,omitempty"`
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
} }
type ChatCompletionChunk struct { type ChatCompletionChunk struct {
@@ -143,7 +141,6 @@ type CompletionRequest struct {
Temperature *float32 `json:"temperature"` Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"` TopP float32 `json:"top_p"`
Suffix string `json:"suffix"` Suffix string `json:"suffix"`
DebugRenderOnly bool `json:"_debug_render_only"`
} }
type Completion struct { type Completion struct {
@@ -276,8 +273,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
} }
return nil return nil
}(r.DoneReason), }(r.DoneReason),
}}, Usage: toUsage(r), }},
DebugInfo: r.DebugInfo, Usage: toUsage(r),
} }
} }
@@ -571,14 +568,13 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
return &api.ChatRequest{ return &api.ChatRequest{
Model: r.Model, Model: r.Model,
Messages: messages, Messages: messages,
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Tools: r.Tools, Tools: r.Tools,
Think: think, Think: think,
DebugRenderOnly: r.DebugRenderOnly,
}, nil }, nil
} }
@@ -652,12 +648,11 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
} }
return api.GenerateRequest{ return api.GenerateRequest{
Model: r.Model, Model: r.Model,
Prompt: r.Prompt, Prompt: r.Prompt,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Suffix: r.Suffix, Suffix: r.Suffix,
DebugRenderOnly: r.DebugRenderOnly,
}, nil }, nil
} }

View File

@@ -62,14 +62,15 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
for _, c := range f.Commands { for _, c := range f.Commands {
switch c.Name { switch c.Name {
case "model": case "model":
path, err := expandPath(c.Args, relativeDir) name := c.Args.(string)
path, err := expandPath(name, relativeDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
digestMap, err := fileDigestMap(path) digestMap, err := fileDigestMap(path)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
req.From = c.Args req.From = name
continue continue
} else if err != nil { } else if err != nil {
return nil, err return nil, err
@@ -83,7 +84,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
} }
} }
case "adapter": case "adapter":
path, err := expandPath(c.Args, relativeDir) adapter := c.Args.(string)
path, err := expandPath(adapter, relativeDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -95,25 +97,25 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.Adapters = digestMap req.Adapters = digestMap
case "template": case "template":
req.Template = c.Args template := c.Args.(string)
req.Template = template
case "system": case "system":
req.System = c.Args system := c.Args.(string)
req.System = system
case "license": case "license":
licenses = append(licenses, c.Args) license := c.Args.(string)
case "renderer": licenses = append(licenses, license)
req.Renderer = c.Args
case "parser":
req.Parser = c.Args
case "message": case "message":
role, msg, _ := strings.Cut(c.Args, ": ") msg := c.Args.(*Message)
messages = append(messages, api.Message{Role: role, Content: msg}) messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
default: case "parameter":
if slices.Contains(deprecatedParameters, c.Name) { if slices.Contains(deprecatedParameters, c.Name) {
fmt.Printf("warning: parameter %s is deprecated\n", c.Name) fmt.Printf("warning: parameter '%s' is deprecated\n", c.Name)
break break
} }
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) param := c.Args.(*Parameter)
ps, err := api.FormatParams(map[string][]string{param.Name: {param.Value}})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,6 +129,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
params[k] = v params[k] = v
} }
} }
default:
return nil, fmt.Errorf("warning: unknown command '%s'", c.Name)
} }
} }
@@ -316,7 +320,17 @@ func filesForModel(path string) ([]string, error) {
type Command struct { type Command struct {
Name string Name string
Args string Args any
}
type Parameter struct {
Name string
Value string
}
type Message struct {
Role string
Content string
} }
func (c Command) String() string { func (c Command) String() string {
@@ -324,13 +338,17 @@ func (c Command) String() string {
switch c.Name { switch c.Name {
case "model": case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args) fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter", "renderer", "parser": case "license", "template", "system", "adapter":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) data := c.Args.(string)
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(data))
case "message": case "message":
role, message, _ := strings.Cut(c.Args, ": ") data := c.Args.(*Message)
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message)) fmt.Fprintf(&sb, "MESSAGE %s %s", data.Role, quote(data.Content))
case "parameter":
data := c.Args.(*Parameter)
fmt.Fprintf(&sb, "PARAMETER %s %s", data.Name, quote(data.Value))
default: default:
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args)) fmt.Printf("unknown command '%s'\n", c.Name)
} }
return sb.String() return sb.String()
@@ -350,7 +368,7 @@ const (
var ( var (
errMissingFrom = errors.New("no FROM line") errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"") errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
) )
type ParserError struct { type ParserError struct {
@@ -370,7 +388,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
var curr state var curr state
var currLine int = 1 var currLine int = 1
var b bytes.Buffer var b bytes.Buffer
var role string
var f Modelfile var f Modelfile
@@ -417,6 +434,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
case "parameter": case "parameter":
// transition to stateParameter which sets command name // transition to stateParameter which sets command name
next = stateParameter next = stateParameter
cmd.Name = s
case "message": case "message":
// transition to stateMessage which validates the message role // transition to stateMessage which validates the message role
next = stateMessage next = stateMessage
@@ -425,16 +443,37 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
cmd.Name = s cmd.Name = s
} }
case stateParameter: case stateParameter:
cmd.Name = b.String() s, ok := unquote(strings.TrimSpace(b.String()))
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
cmd.Args = &Parameter{
Name: s,
}
case stateMessage: case stateMessage:
if !isValidMessageRole(b.String()) { s, ok := unquote(strings.TrimSpace(b.String()))
if !ok || isSpace(r) {
if _, err := b.WriteRune(r); err != nil {
return nil, err
}
continue
}
if !isValidMessageRole(s) {
return nil, &ParserError{ return nil, &ParserError{
LineNumber: currLine, LineNumber: currLine,
Msg: errInvalidMessageRole.Error(), Msg: errInvalidMessageRole.Error(),
} }
} }
role = b.String() cmd.Args = &Message{
Role: s,
}
case stateComment, stateNil: case stateComment, stateNil:
// pass // pass
case stateValue: case stateValue:
@@ -447,12 +486,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
continue continue
} }
if role != "" { switch cmd.Name {
s = role + ": " + s case "parameter":
role = "" p := cmd.Args.(*Parameter)
p.Value = s
case "message":
m := cmd.Args.(*Message)
m.Content = s
default:
cmd.Args = s
} }
cmd.Args = s
f.Commands = append(f.Commands, cmd) f.Commands = append(f.Commands, cmd)
} }
@@ -477,11 +520,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
} }
if role != "" { switch cmd.Name {
s = role + ": " + s case "parameter":
c := cmd.Args.(*Parameter)
c.Value = s
case "message":
c := cmd.Args.(*Message)
c.Content = s
default:
cmd.Args = s
} }
cmd.Args = s
f.Commands = append(f.Commands, cmd) f.Commands = append(f.Commands, cmd)
default: default:
return nil, io.ErrUnexpectedEOF return nil, io.ErrUnexpectedEOF
@@ -610,7 +658,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool { func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) { switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message": case "from", "license", "template", "system", "adapter", "parameter", "message":
return true return true
default: default:
return false return false

View File

@@ -47,8 +47,8 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "model", Args: "model1"}, {Name: "model", Args: "model1"},
{Name: "adapter", Args: "adapter1"}, {Name: "adapter", Args: "adapter1"},
{Name: "license", Args: "MIT"}, {Name: "license", Args: "MIT"},
{Name: "param1", Args: "value1"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}},
{Name: "param2", Args: "value2"}, {Name: "parameter", Args: &Parameter{"param2", "value2"}},
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"}, {Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
} }
@@ -80,8 +80,8 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "model", Args: " model 1"}, {Name: "model", Args: " model 1"},
{Name: "adapter", Args: "adapter3"}, {Name: "adapter", Args: "adapter3"},
{Name: "license", Args: "MIT "}, {Name: "license", Args: "MIT "},
{Name: "param1", Args: "value1"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}},
{Name: "param2", Args: "value2"}, {Name: "parameter", Args: &Parameter{"param2", "value2"}},
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "}, {Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
} }
@@ -101,7 +101,7 @@ func TestParseFileFrom(t *testing.T) {
}, },
{ {
"FROM \"FOO BAR\"\nPARAMETER param1 value1", "FROM \"FOO BAR\"\nPARAMETER param1 value1",
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}}, []Command{{Name: "model", Args: "FOO BAR"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}}},
nil, nil,
}, },
{ {
@@ -149,12 +149,12 @@ func TestParseFileFrom(t *testing.T) {
}, },
{ {
"PARAMETER param1 value1\nFROM foo", "PARAMETER param1 value1\nFROM foo",
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}}, []Command{{Name: "parameter", Args: &Parameter{"param1", "value1"}}, {Name: "model", Args: "foo"}},
nil, nil,
}, },
{ {
"PARAMETER what the \nFROM lemons make lemonade ", "PARAMETER what the \nFROM lemons make lemonade ",
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}}, []Command{{Name: "parameter", Args: &Parameter{"what", "the"}}, {Name: "model", Args: "lemons make lemonade"}},
nil, nil,
}, },
} }
@@ -198,34 +198,6 @@ BADCOMMAND param1 value1
} }
} }
func TestParseFileRenderer(t *testing.T) {
input := `
FROM foo
RENDERER renderer1
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
}
func TestParseFileParser(t *testing.T) {
input := `
FROM foo
PARSER parser1
`
reader := strings.NewReader(input)
modelfile, err := ParseFile(reader)
require.NoError(t, err)
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
}
func TestParseFileMessages(t *testing.T) { func TestParseFileMessages(t *testing.T) {
cases := []struct { cases := []struct {
input string input string
@@ -239,7 +211,7 @@ MESSAGE system You are a file parser. Always parse things.
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
}, },
nil, nil,
}, },
@@ -249,7 +221,7 @@ FROM foo
MESSAGE system You are a file parser. Always parse things.`, MESSAGE system You are a file parser. Always parse things.`,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
}, },
nil, nil,
}, },
@@ -262,9 +234,9 @@ MESSAGE assistant Hello, I want to parse all the things!
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."}, {Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
{Name: "message", Args: "user: Hey there!"}, {Name: "message", Args: &Message{"user", "Hey there!"}},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"}, {Name: "message", Args: &Message{"assistant", "Hello, I want to parse all the things!"}},
}, },
nil, nil,
}, },
@@ -272,12 +244,12 @@ MESSAGE assistant Hello, I want to parse all the things!
` `
FROM foo FROM foo
MESSAGE system """ MESSAGE system """
You are a multiline file parser. Always parse things. You are a multiline file "parser". Always parse things.
""" """
`, `,
[]Command{ []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"}, {Name: "message", Args: &Message{"system", "\nYou are a multiline file \"parser\". Always parse things.\n"}},
}, },
nil, nil,
}, },
@@ -542,7 +514,7 @@ func TestParseFileParameters(t *testing.T) {
assert.Equal(t, []Command{ assert.Equal(t, []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
{Name: v.name, Args: v.value}, {Name: "parameter", Args: &Parameter{v.name, v.value}},
}, modelfile.Commands) }, modelfile.Commands)
}) })
} }
@@ -645,8 +617,8 @@ SYSTEM You are a utf16 file.
expected := []Command{ expected := []Command{
{Name: "model", Args: "bob"}, {Name: "model", Args: "bob"},
{Name: "param1", Args: "1"}, {Name: "parameter", Args: &Parameter{"param1", "1"}},
{Name: "param2", Args: "4096"}, {Name: "parameter", Args: &Parameter{"param2", "4096"}},
{Name: "system", Args: "You are a utf16 file."}, {Name: "system", Args: "You are a utf16 file."},
} }

View File

@@ -204,8 +204,13 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree = max(targetFree, 1) targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
return max(targetFree-currentFree, 0) if discard < 0 {
discard = 0
}
return discard
} }
type ErrReprocessInputs struct { type ErrReprocessInputs struct {

View File

@@ -242,8 +242,13 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
targetFree = max(targetFree, 1) targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
return max(targetFree-currentFree, 0) if discard < 0 {
discard = 0
}
return discard
} }
type ErrReprocessInputs struct { type ErrReprocessInputs struct {

View File

@@ -11,6 +11,7 @@ import (
"image" "image"
"log" "log"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
@@ -31,7 +32,6 @@ import (
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input" "github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
@@ -405,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) { func (s *Server) run(ctx context.Context) {
s.ready.Wait() s.ready.Wait()
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
var activeBatch batchState var activeBatch batchState
for { for {
@@ -900,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone { if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return return
} }

View File

@@ -82,6 +82,7 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1) merges := make([]string, 0, 1)
// Only need vocab for Grammar Test // Only need vocab for Grammar Test
return model.NewBytePairEncoding( return model.NewBytePairEncoding(
``,
&model.Vocabulary{ &model.Vocabulary{
Values: tokens, Values: tokens,
Types: make([]int32, len(vocab)), Types: make([]int32, len(vocab)),

View File

@@ -16,7 +16,6 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
--build-arg=OLLAMA_FAST_BUILD \ --build-arg=OLLAMA_FAST_BUILD \
--build-arg=CUSTOM_CPU_FLAGS \ --build-arg=CUSTOM_CPU_FLAGS \
--build-arg=GPU_RUNNER_CPU_FLAGS \ --build-arg=GPU_RUNNER_CPU_FLAGS \
--build-arg=PARALLEL \
--build-arg=AMDGPU_TARGETS" --build-arg=AMDGPU_TARGETS"
echo "Building Ollama" echo "Building Ollama"

View File

@@ -10,11 +10,8 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"net"
"net/http" "net/http"
"net/url"
"os" "os"
"path"
"path/filepath" "path/filepath"
"slices" "slices"
"strings" "strings"
@@ -42,14 +39,6 @@ var (
) )
func (s *Server) CreateHandler(c *gin.Context) { func (s *Server) CreateHandler(c *gin.Context) {
config := &ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
var r api.CreateRequest var r api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -59,9 +48,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
return return
} }
config.Renderer = r.Renderer
config.Parser = r.Parser
for v := range r.Files { for v := range r.Files {
if !fs.ValidPath(v) { if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
@@ -91,34 +77,20 @@ func (s *Server) CreateHandler(c *gin.Context) {
oldManifest, _ := ParseNamedManifest(name) oldManifest, _ := ParseNamedManifest(name)
var baseLayers []*layerGGML var baseLayers []*layerGGML
var err error
var remote bool
if r.From != "" { if r.From != "" {
slog.Debug("create model from model name", "from", r.From) slog.Debug("create model from model name")
fromName := model.ParseName(r.From) fromName := model.ParseName(r.From)
if !fromName.IsValid() { if !fromName.IsValid() {
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest} ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
return return
} }
if r.RemoteHost != "" {
ru, err := remoteURL(r.RemoteHost)
if err != nil {
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
return
}
config.RemoteModel = r.From ctx, cancel := context.WithCancel(c.Request.Context())
config.RemoteHost = ru defer cancel()
remote = true
} else {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
baseLayers, err = parseFromModel(ctx, fromName, fn) baseLayers, err = parseFromModel(ctx, fromName, fn)
if err != nil { if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
}
} }
} else if r.Files != nil { } else if r.Files != nil {
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn) baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
@@ -138,7 +110,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
} }
var adapterLayers []*layerGGML var adapterLayers []*layerGGML
if !remote && r.Adapters != nil { if r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn) adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil { if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} { for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
@@ -156,56 +128,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
baseLayers = append(baseLayers, adapterLayers...) baseLayers = append(baseLayers, adapterLayers...)
} }
// Info is not currently exposed by Modelfiles, but allows overriding various if err := createModel(r, name, baseLayers, fn); err != nil {
// config values
if r.Info != nil {
caps, ok := r.Info["capabilities"]
if ok {
switch tcaps := caps.(type) {
case []any:
caps := make([]string, len(tcaps))
for i, c := range tcaps {
str, ok := c.(string)
if !ok {
continue
}
caps[i] = str
}
config.Capabilities = append(config.Capabilities, caps...)
}
}
strFromInfo := func(k string) string {
v, ok := r.Info[k]
if ok {
val := v.(string)
return val
}
return ""
}
vFromInfo := func(k string) float64 {
v, ok := r.Info[k]
if ok {
val := v.(float64)
return val
}
return 0
}
config.ModelFamily = strFromInfo("model_family")
if config.ModelFamily != "" {
config.ModelFamilies = []string{config.ModelFamily}
}
config.BaseName = strFromInfo("base_name")
config.FileType = strFromInfo("quantization_level")
config.ModelType = strFromInfo("parameter_size")
config.ContextLen = int(vFromInfo("context_length"))
config.EmbedLen = int(vFromInfo("embedding_length"))
}
if err := createModel(r, name, baseLayers, config, fn); err != nil {
if errors.Is(err, errBadTemplate) { if errors.Is(err, errBadTemplate) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return return
@@ -231,51 +154,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func remoteURL(raw string) (string, error) {
// Specialcase: user supplied only a path ("/foo/bar").
if strings.HasPrefix(raw, "/") {
return (&url.URL{
Scheme: "http",
Host: net.JoinHostPort("localhost", "11434"),
Path: path.Clean(raw),
}).String(), nil
}
if !strings.Contains(raw, "://") {
raw = "http://" + raw
}
if raw == "ollama.com" || raw == "http://ollama.com" {
raw = "https://ollama.com:443"
}
u, err := url.Parse(raw)
if err != nil {
return "", fmt.Errorf("parse error: %w", err)
}
if u.Host == "" {
u.Host = "localhost"
}
hostPart, portPart, err := net.SplitHostPort(u.Host)
if err == nil {
u.Host = net.JoinHostPort(hostPart, portPart)
} else {
u.Host = net.JoinHostPort(u.Host, "11434")
}
if u.Path != "" {
u.Path = path.Clean(u.Path)
}
if u.Path == "/" {
u.Path = ""
}
return u.String(), nil
}
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
switch detectModelTypeFromFiles(files) { switch detectModelTypeFromFiles(files) {
case "safetensors": case "safetensors":
@@ -438,7 +316,15 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
return ggml.KV{}, fmt.Errorf("no base model was found") return ggml.KV{}, fmt.Errorf("no base model was found")
} }
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
var layers []Layer var layers []Layer
for _, layer := range baseLayers { for _, layer := range baseLayers {
if layer.GGML != nil { if layer.GGML != nil {
@@ -518,7 +404,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
return err return err
} }
configLayer, err := createConfigLayer(layers, *config) configLayer, err := createConfigLayer(layers, config)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -104,154 +104,3 @@ func TestConvertFromSafetensors(t *testing.T) {
}) })
} }
} }
func TestRemoteURL(t *testing.T) {
tests := []struct {
name string
input string
expected string
hasError bool
}{
{
name: "absolute path",
input: "/foo/bar",
expected: "http://localhost:11434/foo/bar",
hasError: false,
},
{
name: "absolute path with cleanup",
input: "/foo/../bar",
expected: "http://localhost:11434/bar",
hasError: false,
},
{
name: "root path",
input: "/",
expected: "http://localhost:11434/",
hasError: false,
},
{
name: "host without scheme",
input: "example.com",
expected: "http://example.com:11434",
hasError: false,
},
{
name: "host with port",
input: "example.com:8080",
expected: "http://example.com:8080",
hasError: false,
},
{
name: "full URL",
input: "https://example.com:8080/path",
expected: "https://example.com:8080/path",
hasError: false,
},
{
name: "full URL with path cleanup",
input: "https://example.com:8080/path/../other",
expected: "https://example.com:8080/other",
hasError: false,
},
{
name: "ollama.com special case",
input: "ollama.com",
expected: "https://ollama.com:443",
hasError: false,
},
{
name: "http ollama.com special case",
input: "http://ollama.com",
expected: "https://ollama.com:443",
hasError: false,
},
{
name: "URL with only host",
input: "http://example.com",
expected: "http://example.com:11434",
hasError: false,
},
{
name: "URL with root path cleaned",
input: "http://example.com/",
expected: "http://example.com:11434",
hasError: false,
},
{
name: "invalid URL",
input: "http://[::1]:namedport", // invalid port
expected: "",
hasError: true,
},
{
name: "empty string",
input: "",
expected: "http://localhost:11434",
hasError: false,
},
{
name: "host with scheme but no port",
input: "http://localhost",
expected: "http://localhost:11434",
hasError: false,
},
{
name: "complex path cleanup",
input: "/a/b/../../c/./d",
expected: "http://localhost:11434/c/d",
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := remoteURL(tt.input)
if tt.hasError {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if result != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, result)
}
})
}
}
func TestRemoteURL_Idempotent(t *testing.T) {
// Test that applying remoteURL twice gives the same result as applying it once
testInputs := []string{
"/foo/bar",
"example.com",
"https://example.com:8080/path",
"ollama.com",
"http://localhost:11434",
}
for _, input := range testInputs {
t.Run(input, func(t *testing.T) {
firstResult, err := remoteURL(input)
if err != nil {
t.Fatalf("first call failed: %v", err)
}
secondResult, err := remoteURL(firstResult)
if err != nil {
t.Fatalf("second call failed: %v", err)
}
if firstResult != secondResult {
t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
}
})
}
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf" "github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking" "github.com/ollama/ollama/thinking"
@@ -74,38 +73,29 @@ func (m *Model) Capabilities() []model.Capability {
capabilities := []model.Capability{} capabilities := []model.Capability{}
// Check for completion capability // Check for completion capability
if m.ModelPath != "" { f, err := gguf.Open(m.ModelPath)
f, err := gguf.Open(m.ModelPath) if err == nil {
if err == nil { defer f.Close()
defer f.Close()
if f.KeyValue("pooling_type").Valid() { if f.KeyValue("pooling_type").Valid() {
capabilities = append(capabilities, model.CapabilityEmbedding) capabilities = append(capabilities, model.CapabilityEmbedding)
} else {
// If no embedding is specified, we assume the model supports completion
capabilities = append(capabilities, model.CapabilityCompletion)
}
if f.KeyValue("vision.block_count").Valid() {
capabilities = append(capabilities, model.CapabilityVision)
}
} else { } else {
slog.Error("couldn't open model file", "error", err) // If no embedding is specified, we assume the model supports completion
capabilities = append(capabilities, model.CapabilityCompletion)
} }
} else if len(m.Config.Capabilities) > 0 { if f.KeyValue("vision.block_count").Valid() {
for _, c := range m.Config.Capabilities { capabilities = append(capabilities, model.CapabilityVision)
capabilities = append(capabilities, model.Capability(c))
} }
} else { } else {
slog.Warn("unknown capabilities for model", "model", m.Name) slog.Error("couldn't open model file", "error", err)
} }
if m.Template == nil { if m.Template == nil {
return capabilities return capabilities
} }
builtinParser := parsers.ParserForName(m.Config.Parser)
// Check for tools capability // Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { if slices.Contains(m.Template.Vars(), "tools") {
capabilities = append(capabilities, model.CapabilityTools) capabilities = append(capabilities, model.CapabilityTools)
} }
@@ -119,16 +109,10 @@ func (m *Model) Capabilities() []model.Capability {
capabilities = append(capabilities, model.CapabilityVision) capabilities = append(capabilities, model.CapabilityVision)
} }
// Skip the thinking check if it's already set
if slices.Contains(capabilities, "thinking") {
return capabilities
}
// Check for thinking capability // Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template) openingTag, closingTag := thinking.InferTags(m.Template.Template)
hasTags := openingTag != "" && closingTag != "" hasTags := openingTag != "" && closingTag != ""
isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
capabilities = append(capabilities, model.CapabilityThinking) capabilities = append(capabilities, model.CapabilityThinking)
} }
@@ -214,20 +198,6 @@ func (m *Model) String() string {
}) })
} }
if m.Config.Renderer != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "renderer",
Args: m.Config.Renderer,
})
}
if m.Config.Parser != "" {
modelfile.Commands = append(modelfile.Commands, parser.Command{
Name: "parser",
Args: m.Config.Parser,
})
}
for k, v := range m.Options { for k, v := range m.Options {
switch v := v.(type) { switch v := v.(type) {
case []any: case []any:
@@ -266,19 +236,8 @@ type ConfigV2 struct {
ModelFormat string `json:"model_format"` ModelFormat string `json:"model_format"`
ModelFamily string `json:"model_family"` ModelFamily string `json:"model_family"`
ModelFamilies []string `json:"model_families"` ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"` // shown as Parameter Size ModelType string `json:"model_type"`
FileType string `json:"file_type"` // shown as Quantization Level FileType string `json:"file_type"`
Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"`
RemoteHost string `json:"remote_host,omitempty"`
RemoteModel string `json:"remote_model,omitempty"`
// used for remotes
Capabilities []string `json:"capabilities,omitempty"`
ContextLen int `json:"context_length,omitempty"`
EmbedLen int `json:"embedding_length,omitempty"`
BaseName string `json:"base_name,omitempty"`
// required by spec // required by spec
Architecture string `json:"architecture"` Architecture string `json:"architecture"`

View File

@@ -25,7 +25,10 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
// n^2 backoff timer is a little smoother than the // n^2 backoff timer is a little smoother than the
// common choice of 2^n. // common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff) d := time.Duration(n*n) * 10 * time.Millisecond
if d > maxBackoff {
d = maxBackoff
}
// Randomize the delay between 0.5-1.5 x msec, in order // Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems. // to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5)) d = time.Duration(float64(d) * (rand.Float64() + 0.5))

View File

@@ -11,7 +11,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
@@ -42,12 +41,18 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
} }
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think) thinkVal := false
if err != nil { thinkLevel := ""
if think != nil {
thinkVal = think.Bool()
thinkLevel = think.String()
}
var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
return "", nil, err return "", nil, err
} }
s, err := tokenize(ctx, p) s, err := tokenize(ctx, b.String())
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
@@ -96,23 +101,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
} }
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
if err != nil {
return "", nil, err
}
return p, images, nil
}
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
if m.Config.Renderer != "" {
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
if err != nil {
return "", err
}
return rendered, nil
}
var b bytes.Buffer var b bytes.Buffer
thinkVal := false thinkVal := false
thinkLevel := "" thinkLevel := ""
@@ -120,8 +108,9 @@ func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.Thi
thinkVal = think.Bool() thinkVal = think.Bool()
thinkLevel = think.String() thinkLevel = think.String()
} }
if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
return "", err return "", nil, err
} }
return b.String(), nil
return b.String(), images, nil
} }

View File

@@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -16,7 +15,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"net/url"
"os" "os"
"os/signal" "os/signal"
"slices" "slices"
@@ -30,14 +28,13 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover" "github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil" "github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/server/internal/registry"
@@ -49,8 +46,6 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
func shouldUseHarmony(model *Model) bool { func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony: // heuristic to check whether the template expects to be parsed via harmony:
@@ -153,17 +148,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil return runner.llama, model, &opts, nil
} }
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
return "", err
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
}
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
@@ -204,90 +188,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
return
}
req.Model = m.Config.RemoteModel
if req.Template == "" && m.Template.String() != "" {
req.Template = m.Template.String()
}
if req.Options == nil {
req.Options = map[string]any{}
}
for k, v := range m.Options {
if _, ok := req.Options[k]; !ok {
req.Options[k] = v
}
}
// update the system prompt from the model if one isn't already specified
if req.System == "" && m.System != "" {
req.System = m.System
}
if len(m.Messages) > 0 {
slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
}
fn := func(resp api.GenerateResponse) error {
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
data, err := json.Marshal(resp)
if err != nil {
return err
}
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
return err
}
c.Writer.Flush()
return nil
}
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Generate(c, &req, fn)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
// expire the runner // expire the runner
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 { if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m) s.sched.expireRunner(m)
@@ -307,21 +207,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
var builtinParser parsers.Parser useHarmony := shouldUseHarmony(m) && !req.Raw
if shouldUseHarmony(m) && m.Config.Parser == "" { var harmonyMessageHandler *harmony.HarmonyMessageHandler
m.Config.Parser = "harmony" var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
} }
if !req.Raw && m.Config.Parser != "" { // Validate Think value: string values currently only allowed for gptoss models
builtinParser = parsers.ParserForName(m.Config.Parser) if req.Think != nil && req.Think.IsString() && !useHarmony {
if builtinParser != nil {
// no tools or last message for generate endpoint
builtinParser.Init(nil, nil)
}
}
// Validate Think value: string values currently only allowed for harmony/gptoss models
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
return return
} }
@@ -433,10 +329,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model // If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly { if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.DebugTemplateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
DebugInfo: &api.DebugInfo{ DebugInfo: api.DebugInfo{
RenderedTemplate: prompt, RenderedTemplate: prompt,
ImageCount: len(images), ImageCount: len(images),
}, },
@@ -445,16 +341,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
var thinkingState *thinking.Parser var thinkingState *thinking.Parser
if builtinParser == nil { if !useHarmony {
openingTag, closingTag := thinking.InferTags(m.Template.Template) openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" { if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
thinkingState = &thinking.Parser{ thinkingState = &thinking.Parser{
OpeningTag: openingTag, OpeningTag: openingTag,
ClosingTag: closingTag, ClosingTag: closingTag,
} }
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
thinkingState.AddContent(openingTag)
}
} }
} }
@@ -482,17 +375,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}, },
} }
if builtinParser != nil { if useHarmony {
content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done) content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Response = content res.Response = content
res.Thinking = thinking res.Thinking = thinking
if cr.Done && len(toolCalls) > 0 { harmonyToolParser.Add(toolContent)
res.ToolCalls = toolCalls
}
} else if thinkingState != nil { } else if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content) thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking res.Thinking = thinking
@@ -504,6 +391,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
if cr.Done { if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String() res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart) res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -518,7 +425,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
} }
if builtinParser != nil { if useHarmony {
// only send messages with meaningful content (empty messages confuse clients) // only send messages with meaningful content (empty messages confuse clients)
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
ch <- res ch <- res
@@ -643,7 +550,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen { if len(tokens) > ctxLen {
if !truncate { if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"}) c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return return
} }
@@ -655,13 +562,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen-- ctxLen--
} }
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen] tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens) s, err = r.Detokenize(c.Request.Context(), tokens)
@@ -1030,28 +930,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
ModifiedAt: manifest.fi.ModTime(), ModifiedAt: manifest.fi.ModTime(),
} }
if m.Config.RemoteHost != "" {
resp.RemoteHost = m.Config.RemoteHost
resp.RemoteModel = m.Config.RemoteModel
if m.Config.ModelFamily != "" {
resp.ModelInfo = make(map[string]any)
resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
if m.Config.BaseName != "" {
resp.ModelInfo["general.basename"] = m.Config.BaseName
}
if m.Config.ContextLen > 0 {
resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
}
if m.Config.EmbedLen > 0 {
resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
}
}
}
var params []string var params []string
cs := 30 cs := 30
for k, v := range m.Options { for k, v := range m.Options {
@@ -1082,11 +960,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
fmt.Fprint(&sb, m.String()) fmt.Fprint(&sb, m.String())
resp.Modelfile = sb.String() resp.Modelfile = sb.String()
// skip loading tensor information if this is a remote model
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
return resp, nil
}
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1163,13 +1036,11 @@ func (s *Server) ListHandler(c *gin.Context) {
// tag should never be masked // tag should never be masked
models = append(models, api.ListModelResponse{ models = append(models, api.ListModelResponse{
Model: n.DisplayShortest(), Model: n.DisplayShortest(),
Name: n.DisplayShortest(), Name: n.DisplayShortest(),
RemoteModel: cf.RemoteModel, Size: m.Size(),
RemoteHost: cf.RemoteHost, Digest: m.digest,
Size: m.Size(), ModifiedAt: m.fi.ModTime(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{ Details: api.ModelDetails{
Format: cf.ModelFormat, Format: cf.ModelFormat,
Family: cf.ModelFamily, Family: cf.ModelFamily,
@@ -1429,12 +1300,6 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/show", s.ShowHandler) r.POST("/api/show", s.ShowHandler)
r.DELETE("/api/delete", s.DeleteHandler) r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/me", s.WhoamiHandler)
r.POST("/api/signout", s.SignoutHandler)
// deprecated
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
// Create // Create
r.POST("/api/create", s.CreateHandler) r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler)
@@ -1631,70 +1496,6 @@ func streamResponse(c *gin.Context, ch chan any) {
}) })
} }
func (s *Server) WhoamiHandler(c *gin.Context) {
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
if err != nil {
slog.Error(err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
return
}
client := api.NewClient(u, http.DefaultClient)
user, err := client.Whoami(c)
if err != nil {
slog.Error(err.Error())
}
// user isn't signed in
if user != nil && user.Name == "" {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
c.JSON(http.StatusOK, user)
}
func (s *Server) SignoutHandler(c *gin.Context) {
pubKey, err := auth.GetPublicKey()
if err != nil {
slog.Error("couldn't get public key", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
if err != nil {
slog.Error(err.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
return
}
client := api.NewClient(u, http.DefaultClient)
err = client.Disconnect(c, encKey)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
c.JSON(http.StatusOK, nil)
}
func (s *Server) PsHandler(c *gin.Context) { func (s *Server) PsHandler(c *gin.Context) {
models := []api.ProcessModelResponse{} models := []api.ProcessModelResponse{}
@@ -1751,34 +1552,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
m, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// expire the runner // expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m) model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
s.sched.expireRunner(model)
c.JSON(http.StatusOK, api.ChatResponse{ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model, Model: req.Model,
@@ -1790,83 +1578,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
return
}
req.Model = m.Config.RemoteModel
if req.Options == nil {
req.Options = map[string]any{}
}
msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
msgs = filterThinkTags(msgs, m)
req.Messages = msgs
for k, v := range m.Options {
if _, ok := req.Options[k]; !ok {
req.Options[k] = v
}
}
fn := func(resp api.ChatResponse) error {
resp.Model = origModel
resp.RemoteModel = m.Config.RemoteModel
resp.RemoteHost = m.Config.RemoteHost
data, err := json.Marshal(resp)
if err != nil {
return err
}
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
return err
}
c.Writer.Flush()
return nil
}
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Chat(c, &req, fn)
if err != nil {
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
caps := []model.Capability{model.CapabilityCompletion} caps := []model.Capability{model.CapabilityCompletion}
if len(req.Tools) > 0 { if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools) caps = append(caps, model.CapabilityTools)
@@ -1875,6 +1586,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, model.CapabilityThinking) caps = append(caps, model.CapabilityThinking)
} }
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1903,23 +1625,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
msgs = filterThinkTags(msgs, m) msgs = filterThinkTags(msgs, m)
if shouldUseHarmony(m) && m.Config.Parser == "" { var harmonyMessageHandler *harmony.HarmonyMessageHandler
m.Config.Parser = "harmony" var harmonyToolParser *harmony.HarmonyToolCallAccumulator
}
useHarmony := shouldUseHarmony(m)
var builtinParser parsers.Parser
processedTools := req.Tools processedTools := req.Tools
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
if m.Config.Parser != "" { // make a copy of tools to pass to the chat prompt. Function names may be
builtinParser = parsers.ParserForName(m.Config.Parser) // renamed to be valid Harmony function names.
if builtinParser != nil { processedTools = make([]api.Tool, len(req.Tools))
// Determine last message for chat prefill copy(processedTools, req.Tools)
var lastMessage *api.Message for i, tool := range processedTools {
if len(msgs) > 0 { processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
lastMessage = &msgs[len(msgs)-1]
}
// Initialize parser and get processed tools
processedTools = builtinParser.Init(req.Tools, lastMessage)
} }
} }
@@ -1932,10 +1658,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model // If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly { if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.ChatResponse{ c.JSON(http.StatusOK, api.DebugTemplateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
DebugInfo: &api.DebugInfo{ DebugInfo: api.DebugInfo{
RenderedTemplate: prompt, RenderedTemplate: prompt,
ImageCount: len(images), ImageCount: len(images),
}, },
@@ -1943,8 +1669,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
// Validate Think value: string values currently only allowed for harmony/gptoss models // Validate Think value: string values currently only allowed for gptoss models
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { if req.Think != nil && req.Think.IsString() && !useHarmony {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
return return
} }
@@ -1963,7 +1689,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
var toolParser *tools.Parser var toolParser *tools.Parser
if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) { if len(req.Tools) > 0 && !useHarmony {
toolParser = tools.NewParser(m.Template.Template, req.Tools) toolParser = tools.NewParser(m.Template.Template, req.Tools)
} }
@@ -1995,24 +1721,30 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
if builtinParser != nil { if useHarmony {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Message.Content = content res.Message.Content = content
res.Message.Thinking = thinking res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls harmonyToolParser.Add(toolContent)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done { if r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
} }
return return

View File

@@ -11,7 +11,6 @@ import (
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"slices" "slices"
"strings" "strings"
"testing" "testing"
@@ -21,7 +20,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/types/model"
) )
var stream bool = false var stream bool = false
@@ -617,78 +615,6 @@ func TestCreateTemplateSystem(t *testing.T) {
}) })
} }
func TestCreateAndShowRemoteModel(t *testing.T) {
gin.SetMode(gin.TestMode)
var s Server
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test",
From: "bob",
RemoteHost: "https://ollama.com",
Info: map[string]any{
"capabilities": []string{"completion", "tools", "thinking"},
"model_family": "gptoss",
"context_length": 131072,
"embedding_length": 2880,
"quantization_level": "MXFP4",
"parameter_size": "20.9B",
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("exected status code 200, actual %d", w.Code)
}
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
if w.Code != http.StatusOK {
t.Fatalf("exected status code 200, actual %d", w.Code)
}
var resp api.ShowResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
expectedDetails := api.ModelDetails{
ParentModel: "",
Format: "",
Family: "gptoss",
Families: []string{"gptoss"},
ParameterSize: "20.9B",
QuantizationLevel: "MXFP4",
}
if !reflect.DeepEqual(resp.Details, expectedDetails) {
t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
}
expectedCaps := []model.Capability{
model.Capability("completion"),
model.Capability("tools"),
model.Capability("thinking"),
}
if !slices.Equal(resp.Capabilities, expectedCaps) {
t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
}
v, ok := resp.ModelInfo["gptoss.context_length"]
ctxlen := v.(float64)
if !ok || int(ctxlen) != 131072 {
t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen))
}
v, ok = resp.ModelInfo["gptoss.embedding_length"]
embedlen := v.(float64)
if !ok || int(embedlen) != 2880 {
t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen))
}
fmt.Printf("resp = %#v\n", resp)
}
func TestCreateLicenses(t *testing.T) { func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)

View File

@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
} }
var response api.GenerateResponse var response api.DebugTemplateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err) t.Fatalf("failed to unmarshal response: %v", err)
} }
@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String()) t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
} }
var response api.ChatResponse var response api.DebugTemplateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err) t.Fatalf("failed to unmarshal response: %v", err)
} }

View File

@@ -126,15 +126,7 @@ func TestRoutes(t *testing.T) {
t.Fatalf("failed to create model: %v", err) t.Fatalf("failed to create model: %v", err)
} }
config := &ConfigV2{ if err := createModel(r, modelName, baseLayers, fn); err != nil {
OS: "linux",
Architecture: "amd64",
RootFS: RootFS{
Type: "layers",
},
}
if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
t.Fatal(err) t.Fatal(err)
} }
} }

View File

@@ -382,7 +382,10 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs // load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit. // (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool { func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
numParallel := max(int(envconfig.NumParallel()), 1) numParallel := int(envconfig.NumParallel())
if numParallel < 1 {
numParallel = 1
}
// Embedding models should always be loaded with parallel=1 // Embedding models should always be loaded with parallel=1
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil { if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {