Compare commits
52 Commits
v0.11.11-r
...
nicole/web
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f944382424 | ||
|
|
3aa34ff0e6 | ||
|
|
0797490d9a | ||
|
|
c47154c08d | ||
|
|
b04e46da3e | ||
|
|
34efbbd3f0 | ||
|
|
05ba4ca1f4 | ||
|
|
5a56ff3cf0 | ||
|
|
2fba04b5fb | ||
|
|
fbd82ba5bb | ||
|
|
2e742544bf | ||
|
|
bbb195a6ff | ||
|
|
fd88cd7cb0 | ||
|
|
e1979c571a | ||
|
|
bf78ed6ee9 | ||
|
|
a40d427bce | ||
|
|
64883e3c4c | ||
|
|
41efdd4048 | ||
|
|
c23e6f4cae | ||
|
|
af060eb250 | ||
|
|
ae5c33008e | ||
|
|
3677842ff1 | ||
|
|
242df70a75 | ||
|
|
dba39b2eee | ||
|
|
9f3a37fd36 | ||
|
|
7460259eb3 | ||
|
|
22ccdd74c2 | ||
|
|
0c3d0e7533 | ||
|
|
e7f56ef3d8 | ||
|
|
eb0a5d4459 | ||
|
|
ceac416ec2 | ||
|
|
2717dce6fe | ||
|
|
9b8187b487 | ||
|
|
8b894933a7 | ||
|
|
9c5bf342bc | ||
|
|
564b558c92 | ||
|
|
a417ac97ee | ||
|
|
05d53457af | ||
|
|
b225508c9b | ||
|
|
fa1c987a29 | ||
|
|
ad95d5b30b | ||
|
|
c253433d68 | ||
|
|
a1cff89b30 | ||
|
|
93c64ea1b1 | ||
|
|
3f6642f6fc | ||
|
|
6f7117145f | ||
|
|
472feec2ff | ||
|
|
47991940d4 | ||
|
|
92b96d54ef | ||
|
|
9d56e63dbf | ||
|
|
053092185e | ||
|
|
44a6792873 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,6 +6,7 @@
|
||||
dist
|
||||
build
|
||||
.cache
|
||||
.gocache
|
||||
*.exe
|
||||
.idea
|
||||
test_data
|
||||
|
||||
@@ -98,10 +98,12 @@ check_language(HIP)
|
||||
if(CMAKE_HIP_COMPILER)
|
||||
set(HIP_PLATFORM "amd")
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
if(NOT AMDGPU_TARGETS)
|
||||
find_package(hip REQUIRED)
|
||||
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
|
||||
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
endif()
|
||||
|
||||
if(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
|
||||
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
|
||||
endif()
|
||||
|
||||
|
||||
36
Dockerfile
36
Dockerfile
@@ -1,6 +1,7 @@
|
||||
# vim: filetype=dockerfile
|
||||
|
||||
ARG FLAVOR=${TARGETARCH}
|
||||
ARG PARALLEL=8
|
||||
|
||||
ARG ROCMVERSION=6.3.3
|
||||
ARG JETPACK5VERSION=r35.4.1
|
||||
@@ -34,46 +35,51 @@ ENV LDFLAGS=-s
|
||||
FROM base AS cpu
|
||||
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
|
||||
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CPU' \
|
||||
&& cmake --build --parallel --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
|
||||
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
|
||||
&& cmake --build --parallel --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
|
||||
FROM base AS cuda-13
|
||||
ARG CUDA13VERSION=13.0
|
||||
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
|
||||
&& cmake --build --parallel --preset 'CUDA 13' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
|
||||
FROM base AS rocm-6
|
||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'ROCm 6' \
|
||||
&& cmake --build --parallel --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
|
||||
&& cmake --install build --component HIP --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK5VERSION} AS jetpack-5
|
||||
ARG CMAKEVERSION
|
||||
@@ -81,10 +87,11 @@ RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 5' \
|
||||
&& cmake --build --parallel --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 5' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM --platform=linux/arm64 nvcr.io/nvidia/l4t-jetpack:${JETPACK6VERSION} AS jetpack-6
|
||||
ARG CMAKEVERSION
|
||||
@@ -92,10 +99,11 @@ RUN apt-get update && apt-get install -y curl ccache \
|
||||
&& curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
|
||||
COPY CMakeLists.txt CMakePresets.json .
|
||||
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
|
||||
ARG PARALLEL
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'JetPack 6' \
|
||||
&& cmake --build --parallel --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
&& cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \
|
||||
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
|
||||
|
||||
FROM base AS build
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
|
||||
@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
authError := AuthorizationError{StatusCode: resp.StatusCode}
|
||||
json.Unmarshal(body, &authError)
|
||||
return authError
|
||||
}
|
||||
|
||||
apiError := StatusError{StatusCode: resp.StatusCode}
|
||||
|
||||
err := json.Unmarshal(body, &apiError)
|
||||
@@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
SigninURL string `json:"signin_url,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
@@ -222,7 +229,13 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
return AuthorizationError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
SigninURL: errorResponse.SigninURL,
|
||||
}
|
||||
} else if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
@@ -357,6 +370,24 @@ func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
||||
return &lr, nil
|
||||
}
|
||||
|
||||
// WebSearch performs a web search using the ollama service.
|
||||
func (c *Client) WebSearch(ctx context.Context, req *SearchRequest) (*SearchResponse, error) {
|
||||
var sr SearchResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/web_search", req, &sr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sr, nil
|
||||
}
|
||||
|
||||
// Fetch retrieves content from a URL using the ollama service.
|
||||
func (c *Client) Fetch(ctx context.Context, req *FetchRequest) (*FetchResponse, error) {
|
||||
var fr FetchResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/web_fetch", req, &fr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fr, nil
|
||||
}
|
||||
|
||||
// Copy copies a model - creating a model with another name from an existing
|
||||
// model.
|
||||
func (c *Client) Copy(ctx context.Context, req *CopyRequest) error {
|
||||
@@ -428,3 +459,21 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
// Signout will signout a client for a local ollama server.
|
||||
func (c *Client) Signout(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
||||
}
|
||||
|
||||
// Disconnect will disconnect an ollama instance from ollama.com.
|
||||
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
|
||||
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
|
||||
}
|
||||
|
||||
func (c *Client) Whoami(ctx context.Context) (*UserResponse, error) {
|
||||
var resp UserResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/me", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
160
api/types.go
160
api/types.go
@@ -11,6 +11,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -36,6 +38,19 @@ func (e StatusError) Error() string {
|
||||
}
|
||||
}
|
||||
|
||||
type AuthorizationError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
SigninURL string `json:"signin_url"`
|
||||
}
|
||||
|
||||
func (e AuthorizationError) Error() string {
|
||||
if e.Status != "" {
|
||||
return e.Status
|
||||
}
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
@@ -313,13 +328,29 @@ func (t *ToolFunction) String() string {
|
||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||
// similar to [GenerateResponse].
|
||||
type ChatResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Message Message `json:"message"`
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
|
||||
// RemoteModel is the name of the upstream model that generated the response.
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream Ollama host that generated the response.
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Message contains the message or part of a message from the model.
|
||||
Message Message `json:"message"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
// DoneReason is the reason the model stopped generating text.
|
||||
DoneReason string `json:"done_reason,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
|
||||
Metrics
|
||||
}
|
||||
|
||||
@@ -329,13 +360,6 @@ type DebugInfo struct {
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
// DebugTemplateResponse is returned when _debug_render_only is set to true
|
||||
type DebugTemplateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DebugInfo DebugInfo `json:"_debug_info"`
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
@@ -431,18 +455,47 @@ type EmbeddingResponse struct {
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
// Model is the model name to create.
|
||||
Model string `json:"model"`
|
||||
|
||||
// Stream specifies whether the response is streaming; it is true by default.
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
|
||||
// Quantize is the quantization format for the model; leave blank to not change the quantization level.
|
||||
Quantize string `json:"quantize,omitempty"`
|
||||
|
||||
From string `json:"from,omitempty"`
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
Adapters map[string]string `json:"adapters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
License any `json:"license,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
// From is the name of the model or file to use as the source.
|
||||
From string `json:"from,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream ollama API for the model (if any).
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// Files is a map of files include when creating the model.
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
|
||||
// Adapters is a map of LoRA adapters to include when creating the model.
|
||||
Adapters map[string]string `json:"adapters,omitempty"`
|
||||
|
||||
// Template is the template used when constructing a request to the model.
|
||||
Template string `json:"template,omitempty"`
|
||||
|
||||
// License is a string or list of strings for licenses.
|
||||
License any `json:"license,omitempty"`
|
||||
|
||||
// System is the system prompt for the model.
|
||||
System string `json:"system,omitempty"`
|
||||
|
||||
// Parameters is a map of hyper-parameters which are applied to the model.
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
|
||||
// Messages is a list of messages added to the model before chat and generation requests.
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
// Info is a map of additional information for the model
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -480,8 +533,12 @@ type ShowResponse struct {
|
||||
Parameters string `json:"parameters,omitempty"`
|
||||
Template string `json:"template,omitempty"`
|
||||
System string `json:"system,omitempty"`
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||
Tensors []Tensor `json:"tensors,omitempty"`
|
||||
@@ -540,12 +597,14 @@ type ProcessResponse struct {
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
@@ -569,6 +628,12 @@ type GenerateResponse struct {
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
|
||||
// RemoteModel is the name of the upstream model that generated the response.
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// RemoteHost is the URL of the upstream Ollama host that generated the response.
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
@@ -592,6 +657,8 @@ type GenerateResponse struct {
|
||||
Metrics
|
||||
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
@@ -604,6 +671,18 @@ type ModelDetails struct {
|
||||
QuantizationLevel string `json:"quantization_level"`
|
||||
}
|
||||
|
||||
// UserResponse provides information about a user.
|
||||
type UserResponse struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Bio string `json:"bio,omitempty"`
|
||||
AvatarURL string `json:"avatarurl,omitempty"`
|
||||
FirstName string `json:"firstname,omitempty"`
|
||||
LastName string `json:"lastname,omitempty"`
|
||||
Plan string `json:"plan,omitempty"`
|
||||
}
|
||||
|
||||
// Tensor describes the metadata for a given tensor.
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
@@ -979,3 +1058,30 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Web search types
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
type SearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type SearchResponse struct {
|
||||
Results []SearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// Web fetch types
|
||||
type FetchRequest struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type FetchResponse struct {
|
||||
Content string `json:"content"`
|
||||
Title string `json:"title,omitempty"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
15
auth/auth.go
15
auth/auth.go
@@ -18,21 +18,13 @@ import (
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
func GetPublicKey() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
@@ -59,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
|
||||
214
cmd/cmd.go
214
cmd/cmd.go
@@ -47,6 +47,8 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
@@ -286,7 +288,17 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
return client.Generate(cmd.Context(), req, func(r api.GenerateResponse) error {
|
||||
if r.RemoteModel != "" && opts.ShowConnect {
|
||||
p.StopAndClear()
|
||||
if strings.HasPrefix(r.RemoteHost, "https://ollama.com") {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", r.RemoteModel)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", r.RemoteModel, r.RemoteHost)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func StopHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -307,9 +319,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]any{},
|
||||
ShowConnect: true,
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
@@ -367,6 +380,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.ShowConnect = false
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
@@ -433,6 +447,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
|
||||
if sErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, sErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -453,6 +476,59 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.Name != "" {
|
||||
fmt.Printf("You are already signed in as user '%s'\n", user.Name)
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SignoutHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.Signout(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You are not signed in to ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("You have signed out of ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -464,6 +540,25 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||
_, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You need to be signed in to push models to ollama.com.")
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
@@ -500,12 +595,12 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
errStr := strings.ToLower(err.Error())
|
||||
if strings.Contains(errStr, "access denied") || strings.Contains(errStr, "unauthorized") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
return err
|
||||
@@ -539,7 +634,14 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
for _, m := range models.Models {
|
||||
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
var size string
|
||||
if m.RemoteModel != "" {
|
||||
size = "-"
|
||||
} else {
|
||||
size = format.HumanBytes(m.Size)
|
||||
}
|
||||
|
||||
data = append(data, []string{m.Name, m.Digest[:12], size, format.HumanTime(m.ModifiedAt, "Never")})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -624,8 +726,8 @@ func DeleteHandler(cmd *cobra.Command, args []string) error {
|
||||
KeepAlive: &api.Duration{Duration: 0},
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, opts); err != nil {
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
return fmt.Errorf("unable to stop existing running model \"%s\": %s", args[0], err)
|
||||
if !strings.Contains(strings.ToLower(err.Error()), "not found") {
|
||||
fmt.Fprintf(os.Stderr, "Warning: unable to stop model '%s'\n", args[0])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -736,12 +838,36 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
|
||||
}
|
||||
|
||||
tableRender("Model", func() (rows [][]string) {
|
||||
if resp.RemoteHost != "" {
|
||||
rows = append(rows, []string{"", "Remote model", resp.RemoteModel})
|
||||
rows = append(rows, []string{"", "Remote URL", resp.RemoteHost})
|
||||
}
|
||||
|
||||
if resp.ModelInfo != nil {
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
rows = append(rows, []string{"", "architecture", arch})
|
||||
rows = append(rows, []string{"", "parameters", format.HumanNumber(uint64(resp.ModelInfo["general.parameter_count"].(float64)))})
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64), 'f', -1, 64)})
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)].(float64), 'f', -1, 64)})
|
||||
|
||||
var paramStr string
|
||||
if resp.Details.ParameterSize != "" {
|
||||
paramStr = resp.Details.ParameterSize
|
||||
} else if v, ok := resp.ModelInfo["general.parameter_count"]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
paramStr = format.HumanNumber(uint64(f))
|
||||
}
|
||||
}
|
||||
rows = append(rows, []string{"", "parameters", paramStr})
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok {
|
||||
if f, ok := v.(float64); ok {
|
||||
rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows = append(rows, []string{"", "architecture", resp.Details.Family})
|
||||
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
|
||||
@@ -989,6 +1115,52 @@ type runOptions struct {
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
ShowConnect bool
|
||||
}
|
||||
|
||||
func (r runOptions) Copy() runOptions {
|
||||
var messages []api.Message
|
||||
if r.Messages != nil {
|
||||
messages = make([]api.Message, len(r.Messages))
|
||||
copy(messages, r.Messages)
|
||||
}
|
||||
|
||||
var images []api.ImageData
|
||||
if r.Images != nil {
|
||||
images = make([]api.ImageData, len(r.Images))
|
||||
copy(images, r.Images)
|
||||
}
|
||||
|
||||
var opts map[string]any
|
||||
if r.Options != nil {
|
||||
opts = make(map[string]any, len(r.Options))
|
||||
for k, v := range r.Options {
|
||||
opts[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
if r.Think != nil {
|
||||
cThink := *r.Think
|
||||
think = &cThink
|
||||
}
|
||||
|
||||
return runOptions{
|
||||
Model: r.Model,
|
||||
ParentModel: r.ParentModel,
|
||||
Prompt: r.Prompt,
|
||||
Messages: messages,
|
||||
WordWrap: r.WordWrap,
|
||||
Format: r.Format,
|
||||
System: r.System,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
MultiModal: r.MultiModal,
|
||||
KeepAlive: r.KeepAlive,
|
||||
Think: think,
|
||||
HideThinking: r.HideThinking,
|
||||
ShowConnect: r.ShowConnect,
|
||||
}
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -1544,6 +1716,22 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
pushCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
signinCmd := &cobra.Command{
|
||||
Use: "signin",
|
||||
Short: "Sign in to ollama.com",
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SigninHandler,
|
||||
}
|
||||
|
||||
signoutCmd := &cobra.Command{
|
||||
Use: "signout",
|
||||
Short: "Sign out from ollama.com",
|
||||
Args: cobra.ExactArgs(0),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: SignoutHandler,
|
||||
}
|
||||
|
||||
listCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
@@ -1638,6 +1826,8 @@ func NewCLI() *cobra.Command {
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
signinCmd,
|
||||
signoutCmd,
|
||||
listCmd,
|
||||
psCmd,
|
||||
copyCmd,
|
||||
|
||||
328
cmd/cmd_test.go
328
cmd/cmd_test.go
@@ -3,10 +3,12 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -304,6 +306,8 @@ func TestDeleteHandler(t *testing.T) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
errPayload := `{"error":"model '%s' not found"}`
|
||||
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -346,7 +350,7 @@ func TestDeleteHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
err := DeleteHandler(cmd, []string{"test-model-not-found"})
|
||||
if err == nil || !strings.Contains(err.Error(), "unable to stop existing running model \"test-model-not-found\"") {
|
||||
if err == nil || !strings.Contains(err.Error(), "model 'test-model-not-found' not found") {
|
||||
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -488,9 +492,35 @@ func TestPushHandler(t *testing.T) {
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
},
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||
},
|
||||
{
|
||||
name: "not signed in push",
|
||||
modelName: "notsignedin-model",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "unauthorized",
|
||||
"signin_url": "https://somethingsomething",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "You need to be signed in to push",
|
||||
},
|
||||
{
|
||||
name: "unauthorized push",
|
||||
modelName: "unauthorized-model",
|
||||
@@ -499,12 +529,17 @@ func TestPushHandler(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "access denied",
|
||||
"error": "403: {\"errors\":[{\"code\":\"ACCESS DENIED\", \"message\":\"access denied\"}]}",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
"/api/me": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
@@ -522,6 +557,10 @@ func TestPushHandler(t *testing.T) {
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("HOME", tmpDir)
|
||||
t.Setenv("USERPROFILE", tmpDir)
|
||||
initializeKeypair()
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
@@ -557,7 +596,7 @@ func TestPushHandler(t *testing.T) {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if tt.expectedOutput != "" {
|
||||
if got := string(stdout); got != tt.expectedOutput {
|
||||
if got := string(stdout); !strings.Contains(got, tt.expectedOutput) {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
@@ -915,3 +954,286 @@ func TestNewCreateRequest(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy(t *testing.T) {
|
||||
// Setup test data
|
||||
originalKeepAlive := &api.Duration{Duration: 5 * time.Minute}
|
||||
originalThink := &api.ThinkValue{Value: "test reasoning"}
|
||||
|
||||
original := runOptions{
|
||||
Model: "test-model",
|
||||
ParentModel: "parent-model",
|
||||
Prompt: "test prompt",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi there"},
|
||||
},
|
||||
WordWrap: true,
|
||||
Format: "json",
|
||||
System: "system prompt",
|
||||
Images: []api.ImageData{
|
||||
[]byte("image1"),
|
||||
[]byte("image2"),
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
MultiModal: true,
|
||||
KeepAlive: originalKeepAlive,
|
||||
Think: originalThink,
|
||||
HideThinking: false,
|
||||
ShowConnect: true,
|
||||
}
|
||||
|
||||
// Test the copy
|
||||
copied := original.Copy()
|
||||
|
||||
// Test 1: Verify the copy is not the same instance
|
||||
if &copied == &original {
|
||||
t.Error("Copy should return a different instance")
|
||||
}
|
||||
|
||||
// Test 2: Verify all fields are copied correctly
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
want interface{}
|
||||
}{
|
||||
{"Model", copied.Model, original.Model},
|
||||
{"ParentModel", copied.ParentModel, original.ParentModel},
|
||||
{"Prompt", copied.Prompt, original.Prompt},
|
||||
{"WordWrap", copied.WordWrap, original.WordWrap},
|
||||
{"Format", copied.Format, original.Format},
|
||||
{"System", copied.System, original.System},
|
||||
{"MultiModal", copied.MultiModal, original.MultiModal},
|
||||
{"HideThinking", copied.HideThinking, original.HideThinking},
|
||||
{"ShowConnect", copied.ShowConnect, original.ShowConnect},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if !reflect.DeepEqual(tt.got, tt.want) {
|
||||
t.Errorf("%s mismatch: got %v, want %v", tt.name, tt.got, tt.want)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 3: Verify Messages slice is deeply copied
|
||||
if len(copied.Messages) != len(original.Messages) {
|
||||
t.Errorf("Messages length mismatch: got %d, want %d", len(copied.Messages), len(original.Messages))
|
||||
}
|
||||
|
||||
if len(copied.Messages) > 0 && &copied.Messages[0] == &original.Messages[0] {
|
||||
t.Error("Messages should be different instances")
|
||||
}
|
||||
|
||||
// Modify original to verify independence
|
||||
if len(original.Messages) > 0 {
|
||||
originalContent := original.Messages[0].Content
|
||||
original.Messages[0].Content = "modified"
|
||||
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||
t.Error("Messages should be independent after copy")
|
||||
}
|
||||
// Restore for other tests
|
||||
original.Messages[0].Content = originalContent
|
||||
}
|
||||
|
||||
// Test 4: Verify Images slice is deeply copied
|
||||
if len(copied.Images) != len(original.Images) {
|
||||
t.Errorf("Images length mismatch: got %d, want %d", len(copied.Images), len(original.Images))
|
||||
}
|
||||
|
||||
if len(copied.Images) > 0 && &copied.Images[0] == &original.Images[0] {
|
||||
t.Error("Images should be different instances")
|
||||
}
|
||||
|
||||
// Modify original to verify independence
|
||||
if len(original.Images) > 0 {
|
||||
originalImage := original.Images[0]
|
||||
original.Images[0] = []byte("modified")
|
||||
if len(copied.Images) > 0 && string(copied.Images[0]) == "modified" {
|
||||
t.Error("Images should be independent after copy")
|
||||
}
|
||||
// Restore for other tests
|
||||
original.Images[0] = originalImage
|
||||
}
|
||||
|
||||
// Test 5: Verify Options map is deeply copied
|
||||
if len(copied.Options) != len(original.Options) {
|
||||
t.Errorf("Options length mismatch: got %d, want %d", len(copied.Options), len(original.Options))
|
||||
}
|
||||
|
||||
if len(copied.Options) > 0 && &copied.Options == &original.Options {
|
||||
t.Error("Options map should be different instances")
|
||||
}
|
||||
|
||||
// Modify original to verify independence
|
||||
if len(original.Options) > 0 {
|
||||
originalTemp := original.Options["temperature"]
|
||||
original.Options["temperature"] = 0.9
|
||||
if copied.Options["temperature"] == 0.9 {
|
||||
t.Error("Options should be independent after copy")
|
||||
}
|
||||
// Restore for other tests
|
||||
original.Options["temperature"] = originalTemp
|
||||
}
|
||||
|
||||
// Test 6: Verify KeepAlive pointer is copied (shallow copy)
|
||||
if copied.KeepAlive != original.KeepAlive {
|
||||
t.Error("KeepAlive pointer should be the same (shallow copy)")
|
||||
}
|
||||
|
||||
// Test 7: Verify Think pointer creates a new instance
|
||||
if original.Think != nil && copied.Think == original.Think {
|
||||
t.Error("Think should be a different instance")
|
||||
}
|
||||
|
||||
if original.Think != nil && copied.Think != nil {
|
||||
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
|
||||
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 8: Test with zero values
|
||||
zeroOriginal := runOptions{}
|
||||
zeroCopy := zeroOriginal.Copy()
|
||||
|
||||
if !reflect.DeepEqual(zeroCopy, zeroOriginal) {
|
||||
fmt.Printf("orig: %#v\ncopy: %#v\n", zeroOriginal, zeroCopy)
|
||||
t.Error("Copy of zero value should equal original zero value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_EmptySlicesAndMaps(t *testing.T) {
|
||||
// Test with empty slices and maps
|
||||
original := runOptions{
|
||||
Messages: []api.Message{},
|
||||
Images: []api.ImageData{},
|
||||
Options: map[string]any{},
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
if copied.Messages == nil {
|
||||
t.Error("Empty Messages slice should remain empty, not nil")
|
||||
}
|
||||
|
||||
if copied.Images == nil {
|
||||
t.Error("Empty Images slice should remain empty, not nil")
|
||||
}
|
||||
|
||||
if copied.Options == nil {
|
||||
t.Error("Empty Options map should remain empty, not nil")
|
||||
}
|
||||
|
||||
if len(copied.Messages) != 0 {
|
||||
t.Error("Empty Messages slice should remain empty")
|
||||
}
|
||||
|
||||
if len(copied.Images) != 0 {
|
||||
t.Error("Empty Images slice should remain empty")
|
||||
}
|
||||
|
||||
if len(copied.Options) != 0 {
|
||||
t.Error("Empty Options map should remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_NilPointers(t *testing.T) {
|
||||
// Test with nil pointers
|
||||
original := runOptions{
|
||||
KeepAlive: nil,
|
||||
Think: nil,
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
if copied.KeepAlive != nil {
|
||||
t.Error("Nil KeepAlive should remain nil")
|
||||
}
|
||||
|
||||
if copied.Think != nil {
|
||||
t.Error("Nil Think should remain nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_ThinkValueVariants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
think *api.ThinkValue
|
||||
}{
|
||||
{"nil Think", nil},
|
||||
{"bool true", &api.ThinkValue{Value: true}},
|
||||
{"bool false", &api.ThinkValue{Value: false}},
|
||||
{"string value", &api.ThinkValue{Value: "reasoning text"}},
|
||||
{"int value", &api.ThinkValue{Value: 42}},
|
||||
{"nil value", &api.ThinkValue{Value: nil}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
original := runOptions{Think: tt.think}
|
||||
copied := original.Copy()
|
||||
|
||||
if tt.think == nil {
|
||||
if copied.Think != nil {
|
||||
t.Error("Nil Think should remain nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if copied.Think == nil {
|
||||
t.Error("Non-nil Think should not become nil")
|
||||
return
|
||||
}
|
||||
|
||||
if copied.Think == original.Think {
|
||||
t.Error("Think should be a different instance")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(copied.Think.Value, original.Think.Value) {
|
||||
t.Errorf("Think.Value mismatch: got %v, want %v", copied.Think.Value, original.Think.Value)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOptions_Copy_Independence(t *testing.T) {
|
||||
// Test that modifications to original don't affect copy
|
||||
originalThink := &api.ThinkValue{Value: "original"}
|
||||
original := runOptions{
|
||||
Model: "original-model",
|
||||
Messages: []api.Message{{Role: "user", Content: "original"}},
|
||||
Options: map[string]any{"key": "value"},
|
||||
Think: originalThink,
|
||||
}
|
||||
|
||||
copied := original.Copy()
|
||||
|
||||
// Modify original
|
||||
original.Model = "modified-model"
|
||||
if len(original.Messages) > 0 {
|
||||
original.Messages[0].Content = "modified"
|
||||
}
|
||||
original.Options["key"] = "modified"
|
||||
if original.Think != nil {
|
||||
original.Think.Value = "modified"
|
||||
}
|
||||
|
||||
// Verify copy is unchanged
|
||||
if copied.Model == "modified-model" {
|
||||
t.Error("Copy Model should not be affected by original modification")
|
||||
}
|
||||
|
||||
if len(copied.Messages) > 0 && copied.Messages[0].Content == "modified" {
|
||||
t.Error("Copy Messages should not be affected by original modification")
|
||||
}
|
||||
|
||||
if copied.Options["key"] == "modified" {
|
||||
t.Error("Copy Options should not be affected by original modification")
|
||||
}
|
||||
|
||||
if copied.Think != nil && copied.Think.Value == "modified" {
|
||||
t.Error("Copy Think should not be affected by original modification")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,16 +195,24 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Println("Usage:\n /load <modelname>")
|
||||
continue
|
||||
}
|
||||
origOpts := opts.Copy()
|
||||
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||
opts = origOpts.Copy()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
fmt.Printf("Couldn't find model '%s'\n", opts.Model)
|
||||
opts = origOpts.Copy()
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
|
||||
@@ -28,6 +28,7 @@ type bertModel struct {
|
||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||
NormEpsilon float32 `json:"norm_epsilon"`
|
||||
normalizeEmbeddings bool
|
||||
|
||||
PoolingType uint32
|
||||
}
|
||||
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
||||
|
||||
var pooling string
|
||||
for _, m := range modules {
|
||||
if m.Type == "sentence_transformers.models.Pooling" {
|
||||
switch m.Type {
|
||||
case "sentence_transformers.models.Pooling":
|
||||
pooling = m.Path
|
||||
break
|
||||
case "sentence_transformers.models.Normalize":
|
||||
p.normalizeEmbeddings = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv["general.architecture"] = "bert"
|
||||
kv["bert.attention.causal"] = false
|
||||
kv["bert.pooling_type"] = p.PoolingType
|
||||
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
|
||||
|
||||
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ type safetensor struct {
|
||||
|
||||
func (st safetensor) Kind() uint32 {
|
||||
kind := st.tensorBase.Kind()
|
||||
if st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
|
||||
@@ -230,3 +230,65 @@ func TestSafetensors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafetensorKind(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
st safetensor
|
||||
expected uint32
|
||||
}{
|
||||
{
|
||||
name: "BF16 dtype with non-v. prefix and non-FP32 base kind should return BF16",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindBF16,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with v. prefix should return base kind",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "v.weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
{
|
||||
name: "BF16 dtype with FP32 base kind should return FP32",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10}, // will default to FP32
|
||||
},
|
||||
dtype: "BF16",
|
||||
},
|
||||
expected: tensorKindFP32,
|
||||
},
|
||||
{
|
||||
name: "Non-BF16 dtype should return base kind",
|
||||
st: safetensor{
|
||||
tensorBase: &tensorBase{
|
||||
name: "weight.matrix",
|
||||
shape: []uint64{10, 10}, // will default to FP16
|
||||
},
|
||||
dtype: "FP16",
|
||||
},
|
||||
expected: tensorKindFP16,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.st.Kind()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Kind() = %d, expected %d", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
func cudaVariant(gpuInfos []CudaGPUInfo) string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
ver := strings.Split(CudaTegra, ".")
|
||||
@@ -45,12 +45,19 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
}
|
||||
}
|
||||
|
||||
if gpuInfo.DriverMajor < 13 {
|
||||
// Check GPU compute capability FIRST, lowest common denominator if multi-gpu
|
||||
for _, gpuInfo := range gpuInfos {
|
||||
if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
|
||||
// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
|
||||
return "v12"
|
||||
}
|
||||
}
|
||||
|
||||
// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
|
||||
if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
|
||||
// The detected driver is older than 580 (Aug 2025)
|
||||
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
|
||||
if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
||||
}
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
|
||||
return "v12"
|
||||
}
|
||||
return "v13"
|
||||
|
||||
@@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList {
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
gpuInfo.DriverMajor = driverMajor
|
||||
gpuInfo.DriverMinor = driverMinor
|
||||
variant := cudaVariant(gpuInfo)
|
||||
|
||||
// Start with our bundled libraries
|
||||
if variant != "" {
|
||||
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
|
||||
if _, err := os.Stat(variantPath); err == nil {
|
||||
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
||||
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
|
||||
}
|
||||
}
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.Variant = variant
|
||||
|
||||
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
|
||||
unsupportedGPUs = append(unsupportedGPUs,
|
||||
@@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList {
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||
}
|
||||
// Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
|
||||
variant := cudaVariant(cudaGPUs)
|
||||
var variantPath string
|
||||
// Start with our bundled libraries
|
||||
if variant != "" {
|
||||
variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
|
||||
if _, err := os.Stat(variantPath); err != nil {
|
||||
variantPath = ""
|
||||
}
|
||||
}
|
||||
|
||||
for i := range cudaGPUs {
|
||||
cudaGPUs[i].Variant = variant
|
||||
if variantPath != "" {
|
||||
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
|
||||
cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Intel
|
||||
|
||||
40
docs/cloud.md
Normal file
40
docs/cloud.md
Normal file
@@ -0,0 +1,40 @@
|
||||
# Cloud
|
||||
|
||||
| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud).
|
||||
|
||||
## Cloud Models
|
||||
|
||||
[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn’t fit on a personal computer.
|
||||
|
||||
Ollama currently supports the following cloud models, with more coming soon:
|
||||
|
||||
- `gpt-oss:20b-cloud`
|
||||
- `gpt-oss:120b-cloud`
|
||||
- `deepseek-v3.1:671b-cloud`
|
||||
- `qwen3-coder:480b-cloud`
|
||||
|
||||
### Get started
|
||||
|
||||
To run a cloud model, open the terminal and run:
|
||||
|
||||
```
|
||||
ollama run gpt-oss:120b-cloud
|
||||
```
|
||||
|
||||
To run cloud models with integrations that work with Ollama, first download the cloud model:
|
||||
|
||||
```
|
||||
ollama pull qwen3-coder:480b-cloud
|
||||
```
|
||||
|
||||
Then sign in to Ollama:
|
||||
|
||||
```
|
||||
ollama signin
|
||||
```
|
||||
|
||||
Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling.
|
||||
|
||||
## Cloud API access
|
||||
|
||||
Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud).
|
||||
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
|
||||
go run . serve
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
|
||||
|
||||
|
||||
## macOS (Apple Silicon)
|
||||
|
||||
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
||||
|
||||
107
docs/turbo.md
107
docs/turbo.md
@@ -1,107 +0,0 @@
|
||||
# Turbo
|
||||
|
||||
> ⚠️ Turbo is preview
|
||||
|
||||
Ollama’s [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
|
||||
|
||||
Currently, the following models are available in Turbo:
|
||||
|
||||
- `gpt-oss:20b`
|
||||
- `gpt-oss:120b`
|
||||
|
||||
## Get started
|
||||
|
||||
### Ollama for macOS & Windows
|
||||
|
||||
Download Ollama
|
||||
|
||||
- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
|
||||
- Click on **Turbo**. You’ll be prompted to create an account or sign in
|
||||
|
||||
### Ollama’s CLI
|
||||
|
||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
||||
- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
|
||||
|
||||
On macOS and Linux:
|
||||
|
||||
```shell
|
||||
cat ~/.ollama/id_ed25519.pub
|
||||
```
|
||||
|
||||
On Windows:
|
||||
|
||||
```
|
||||
type "%USERPROFILE%\.ollama\id_ed25519.pub"
|
||||
```
|
||||
|
||||
- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
|
||||
```shell
|
||||
OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
|
||||
```
|
||||
|
||||
### Ollama’s Python library
|
||||
|
||||
- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
|
||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
||||
- Create an API key by visiting https://ollama.com/settings/keys
|
||||
|
||||
```python
|
||||
from ollama import Client
|
||||
|
||||
client = Client(
|
||||
host="https://ollama.com",
|
||||
headers={'Authorization': '<api key>'}
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
},
|
||||
]
|
||||
|
||||
for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
|
||||
print(part['message']['content'], end='', flush=True)
|
||||
```
|
||||
|
||||
### Ollama’s JavaScript library
|
||||
|
||||
- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
|
||||
- [Sign up](https://ollama.com/signup) for an Ollama account
|
||||
- Create an API key by visiting https://ollama.com/settings/keys
|
||||
|
||||
```typescript
|
||||
import { Ollama } from 'ollama';
|
||||
|
||||
const ollama = new Ollama({
|
||||
host: 'https://ollama.com',
|
||||
headers: {
|
||||
Authorization: "Bearer <api key>"
|
||||
}
|
||||
});
|
||||
|
||||
const response = await ollama.chat({
|
||||
model: 'gpt-oss:120b',
|
||||
messages: [{ role: 'user', content: 'Explain quantum computing' }],
|
||||
stream: true
|
||||
});
|
||||
|
||||
for await (const part of response) {
|
||||
process.stdout.write(part.message.content)
|
||||
}
|
||||
```
|
||||
|
||||
### Community integrations
|
||||
|
||||
Turbo mode is also compatible with several community integrations.
|
||||
|
||||
#### Open WebUI
|
||||
|
||||
- Go to **settings** → **Admin settings** → **Connections**
|
||||
- Under **Ollama API,** click **+**
|
||||
- For the **URL** put `https://ollama.com`
|
||||
- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
|
||||
- Click **Save**
|
||||
|
||||
Now, if you navigate to the model selector, Turbo models should be available under **External**.
|
||||
@@ -134,6 +134,17 @@ func LoadTimeout() (loadTimeout time.Duration) {
|
||||
return loadTimeout
|
||||
}
|
||||
|
||||
func Remotes() []string {
|
||||
var r []string
|
||||
raw := strings.TrimSpace(Var("OLLAMA_REMOTES"))
|
||||
if raw == "" {
|
||||
r = []string{"ollama.com"}
|
||||
} else {
|
||||
r = strings.Split(raw, ",")
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func Bool(k string) func() bool {
|
||||
return func() bool {
|
||||
if s := Var(k); s != "" {
|
||||
@@ -270,6 +281,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
@@ -243,6 +243,8 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"qwen3",
|
||||
"qwen3moe",
|
||||
"llama4",
|
||||
"mllama",
|
||||
"qwen25vl",
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
type harmonyParserState int
|
||||
|
||||
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if template.Contains("<|start|>") && template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||
harmonyParserState_ParsingHeader
|
||||
@@ -89,28 +76,18 @@ func (s *HarmonyParser) AddImplicitStart() {
|
||||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func Prefill(lastMessage api.Message) string {
|
||||
if lastMessage.Role != "assistant" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(lastMessage.Content) != "":
|
||||
return "<|start|>assistant<|channel|>final<|message|>"
|
||||
case strings.TrimSpace(lastMessage.Thinking) != "":
|
||||
return "<|start|>assistant<|channel|>analysis<|message|>"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
|
||||
if strings.TrimSpace(prefillString) != "" {
|
||||
s.acc.WriteString(prefillString)
|
||||
} else {
|
||||
s.AddImplicitStart()
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
// handle prefilling conditions
|
||||
if lastMessage.Content != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||
return
|
||||
} else if lastMessage.Thinking != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||
return
|
||||
}
|
||||
}
|
||||
s.AddImplicitStart()
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
||||
@@ -289,7 +266,8 @@ type HarmonyMessageHandler struct {
|
||||
state harmonyMessageState
|
||||
HarmonyParser *HarmonyParser
|
||||
FunctionNameMap *FunctionNameMap
|
||||
ToolParser *HarmonyToolCallAccumulator
|
||||
toolAccumulator *HarmonyToolCallAccumulator
|
||||
convertedTools map[string]struct{}
|
||||
}
|
||||
|
||||
// NewHarmonyMessageHandler creates a new message handler
|
||||
@@ -302,16 +280,13 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||
HeaderEndTag: "<|message|>",
|
||||
},
|
||||
FunctionNameMap: NewFunctionNameMap(),
|
||||
ToolParser: &HarmonyToolCallAccumulator{
|
||||
state: harmonyToolCallState_Normal,
|
||||
currentToolName: nil,
|
||||
},
|
||||
convertedTools: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// AddContent processes the content and returns the content, thinking, and tool content.
|
||||
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
|
||||
func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
|
||||
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
|
||||
contentSb := strings.Builder{}
|
||||
thinkingSb := strings.Builder{}
|
||||
toolContentSb := strings.Builder{}
|
||||
@@ -328,14 +303,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
|
||||
// event.Header.Recipient is the tool name, something like
|
||||
// "browser.search" for a built-in, or "functions.calc" for a
|
||||
// custom one
|
||||
h.ToolParser.SetToolName(event.Header.Recipient)
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Thinking
|
||||
}
|
||||
case "commentary":
|
||||
if event.Header.Recipient != "" {
|
||||
h.state = harmonyMessageState_ToolCalling
|
||||
h.ToolParser.SetToolName(event.Header.Recipient)
|
||||
toolParser.SetToolName(event.Header.Recipient)
|
||||
} else {
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
@@ -358,6 +333,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
|
||||
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
|
||||
}
|
||||
|
||||
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
|
||||
return &HarmonyToolCallAccumulator{
|
||||
state: harmonyToolCallState_Normal,
|
||||
currentToolName: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type harmonyToolCallState int
|
||||
|
||||
const (
|
||||
@@ -406,8 +388,85 @@ func NewFunctionNameMap() *FunctionNameMap {
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the handler with tools and optional last message
|
||||
// Implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
// Initialize the harmony parser
|
||||
if h.HarmonyParser == nil {
|
||||
h.HarmonyParser = &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
}
|
||||
}
|
||||
|
||||
// Handle prefill for chat mode
|
||||
if lastMessage != nil {
|
||||
h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||
} else {
|
||||
h.HarmonyParser.AddImplicitStart()
|
||||
}
|
||||
|
||||
// Initialize tool accumulator
|
||||
h.toolAccumulator = h.CreateToolParser()
|
||||
|
||||
// Process tools and return renamed versions
|
||||
if len(tools) == 0 {
|
||||
return tools
|
||||
}
|
||||
|
||||
processedTools := make([]api.Tool, len(tools))
|
||||
copy(processedTools, tools)
|
||||
for i, tool := range processedTools {
|
||||
if tool.Function.Name != "" {
|
||||
processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
h.convertedTools[tool.Function.Name] = struct{}{}
|
||||
}
|
||||
}
|
||||
return processedTools
|
||||
}
|
||||
|
||||
// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls
|
||||
func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
content, thinking, toolContent := h.AddContent(s, h.toolAccumulator)
|
||||
if toolContent != "" {
|
||||
h.toolAccumulator.Add(toolContent)
|
||||
}
|
||||
|
||||
// tool calls always happen one at a time, and always at the end of a message,
|
||||
// so for simplicity we defer parsing them until we know we're done
|
||||
if done {
|
||||
toolName, raw := h.toolAccumulator.Drain()
|
||||
if toolName != nil {
|
||||
name := strings.TrimPrefix(*toolName, "functions.")
|
||||
name = h.FunctionNameMap.OriginalFromConverted(name)
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
|
||||
}
|
||||
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
|
||||
}
|
||||
}
|
||||
|
||||
return content, thinking, calls, nil
|
||||
}
|
||||
|
||||
// HasToolSupport implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// HasThinkingSupport implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||
// built-in functions should not be renamed
|
||||
if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" {
|
||||
harmonyFunctionName = userFunctionName
|
||||
}
|
||||
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||
return harmonyFunctionName
|
||||
|
||||
@@ -3,7 +3,6 @@ package harmony
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -514,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||
{name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
@@ -536,202 +536,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
|
||||
t.Run("thinking_then_content_streams", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
type step struct {
|
||||
in string
|
||||
wantContent string
|
||||
wantThinking string
|
||||
}
|
||||
steps := []step{
|
||||
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
|
||||
{in: "<|end|>", wantThinking: ""},
|
||||
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
|
||||
{in: "<|end|>", wantContent: ""},
|
||||
}
|
||||
for i, s := range steps {
|
||||
content, thinking, tool := handler.AddContent(s.in)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if content != s.wantContent || thinking != s.wantThinking {
|
||||
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
inputs := []string{
|
||||
"<|start|>assistant<|message|>Hello",
|
||||
", world",
|
||||
"!<|end|>",
|
||||
}
|
||||
var got []string
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("unexpected thinking %q", thinking)
|
||||
}
|
||||
if content != "" {
|
||||
got = append(got, content)
|
||||
}
|
||||
}
|
||||
want := []string{"Hello", ", world", "!"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
inputs := []string{
|
||||
"<|channel|>analysis<|message|>Thinking...",
|
||||
"<|end|>",
|
||||
"<|start|>assistant<|message|>Answer",
|
||||
"<|end|>",
|
||||
}
|
||||
var got []string
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if thinking != "" {
|
||||
got = append(got, thinking)
|
||||
}
|
||||
if content != "" {
|
||||
got = append(got, content)
|
||||
}
|
||||
}
|
||||
want := []string{"Thinking...", "Answer"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
inputs := []string{
|
||||
"<|chan",
|
||||
"nel|>analysis<|mess",
|
||||
"age|>Deep ",
|
||||
"thought",
|
||||
"<|end|>",
|
||||
"<|start|>assistant<|message|>Done",
|
||||
"<|end|>",
|
||||
}
|
||||
var thinkingPieces []string
|
||||
var contentPieces []string
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if thinking != "" {
|
||||
thinkingPieces = append(thinkingPieces, thinking)
|
||||
}
|
||||
if content != "" {
|
||||
contentPieces = append(contentPieces, content)
|
||||
}
|
||||
}
|
||||
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
|
||||
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
|
||||
}
|
||||
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
|
||||
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
inputs := []string{
|
||||
"<|channel|>analysis<|message|>Think",
|
||||
"<|end|>",
|
||||
"<|start|>assistant<|message|>Answer",
|
||||
"<|end|>",
|
||||
}
|
||||
var contentSb, thinkingSb strings.Builder
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
contentSb.WriteString(content)
|
||||
thinkingSb.WriteString(thinking)
|
||||
}
|
||||
if contentSb.String() != "Answer" {
|
||||
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
|
||||
}
|
||||
if thinkingSb.String() != "Think" {
|
||||
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
inputs := []string{
|
||||
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
|
||||
}
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in)
|
||||
if content != "" || thinking != "" {
|
||||
continue
|
||||
}
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
}
|
||||
name, args := tp.Drain()
|
||||
if name == nil || *name != "functions.calculate" {
|
||||
t.Fatalf("unexpected tool name: %v", name)
|
||||
}
|
||||
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
||||
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_call_across_chunks", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.ToolParser
|
||||
inputs := []string{
|
||||
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
|
||||
"2\"}",
|
||||
"<|end|>",
|
||||
}
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in)
|
||||
if content != "" || thinking != "" {
|
||||
continue
|
||||
}
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
}
|
||||
name, args := tp.Drain()
|
||||
if name == nil || *name != "functions.calculate" {
|
||||
t.Fatalf("unexpected tool name: %v", name)
|
||||
}
|
||||
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
||||
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,3 +12,6 @@ The integration tests have 2 modes of operating.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
|
||||
|
||||
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`
|
||||
@@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue? be brief",
|
||||
Prompt: blueSkyPrompt,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
@@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) {
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
t.Errorf("none of %v found in %s", blueSkyExpected, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
@@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) {
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "why is the sky blue? be brief",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Options: map[string]interface{}{
|
||||
@@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) {
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering"}
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
@@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) {
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Errorf("none of %v found in %s", anyResp, response)
|
||||
t.Errorf("none of %v found in %s", blueSkyExpected, response)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for chat")
|
||||
|
||||
@@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
Prompt: blueSkyPrompt,
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
func TestUnicode(t *testing.T) {
|
||||
@@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "why is the sky blue?",
|
||||
Prompt: blueSkyPrompt,
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"})
|
||||
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Prompt: "Write me a story in english with a lot of emojis",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
func TestParallelGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := GenerateRequests()
|
||||
numParallel := 2
|
||||
@@ -113,8 +113,48 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: rainbowPrompt,
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"num_ctx": 16384,
|
||||
},
|
||||
}
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
|
||||
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
|
||||
for i := 0; i < len(rainbowFollowups); i++ {
|
||||
req.Prompt = rainbowFollowups[i]
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
func TestParallelChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
@@ -164,3 +204,55 @@ func TestChatWithHistory(t *testing.T) {
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"num_ctx": 16384,
|
||||
},
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: rainbowPrompt,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
|
||||
assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
|
||||
for i := 0; i < len(rainbowFollowups); i++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
req.Messages = append(req.Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: rainbowFollowups[i]},
|
||||
)
|
||||
|
||||
assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
@@ -44,9 +45,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embedding) != 384 {
|
||||
@@ -74,9 +74,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 1 {
|
||||
@@ -112,9 +111,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
|
||||
}
|
||||
|
||||
res, err := embedTestHelper(ctx, client, t, req)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(res.Embeddings) != 2 {
|
||||
@@ -156,93 +154,135 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
|
||||
|
||||
truncTrue, truncFalse := true, false
|
||||
|
||||
type testReq struct {
|
||||
Name string
|
||||
Request api.EmbedRequest
|
||||
want, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
reqs := []testReq{
|
||||
cases := []struct {
|
||||
name string
|
||||
request api.EmbedRequest
|
||||
check func(*api.EmbedResponse, error)
|
||||
}{
|
||||
{
|
||||
Name: "Target Truncation",
|
||||
Request: api.EmbedRequest{
|
||||
name: "target truncation",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Default Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Explicit Truncate",
|
||||
Request: api.EmbedRequest{
|
||||
name: "default truncate",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "explicit truncate",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(got *api.EmbedResponse, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
|
||||
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 3},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "input after truncate error",
|
||||
request: api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncTrue,
|
||||
Options: map[string]any{"num_ctx": 0},
|
||||
},
|
||||
check: func(res *api.EmbedResponse, err error) {
|
||||
if err.Error() != "input after truncation exceeds maximum context length" {
|
||||
t.Fatalf("expected truncation error, got: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
res := make(map[string]*api.EmbedResponse)
|
||||
|
||||
for _, req := range reqs {
|
||||
response, err := embedTestHelper(ctx, client, t, req.Request)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
res[req.Name] = response
|
||||
}
|
||||
|
||||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request to truncate correctly")
|
||||
}
|
||||
|
||||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
|
||||
t.Fatal("expected default request and truncate true request to be the same")
|
||||
}
|
||||
|
||||
// check that truncate set to false returns an error if context length is exceeded
|
||||
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{
|
||||
Model: "all-minilm",
|
||||
Input: "why is the sky blue?",
|
||||
Truncate: &truncFalse,
|
||||
Options: map[string]any{"num_ctx": 1},
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
for _, req := range cases {
|
||||
t.Run(req.name, func(t *testing.T) {
|
||||
req.check(embedTestHelper(ctx, client, t, req.request))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
|
||||
t.Helper()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
response, err := client.Embeddings(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return client.Embeddings(ctx, &req)
|
||||
}
|
||||
|
||||
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
|
||||
t.Helper()
|
||||
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("failed to pull model %s: %v", req.Model, err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
response, err := client.Embed(ctx, &req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return client.Embed(ctx, &req)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
||||
|
||||
chatModels := libraryChatModels
|
||||
for _, model := range chatModels {
|
||||
@@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
if targetArch != "" {
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Name: model})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to show model: %s", err)
|
||||
}
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
if arch != targetArch {
|
||||
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||
}
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Prompt: blueSkyPrompt,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0.1,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
|
||||
anyResp := blueSkyExpected
|
||||
// Special cases
|
||||
if model == "duckdb-nsql" {
|
||||
anyResp = []string{"select", "from"}
|
||||
|
||||
@@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) {
|
||||
// TODO - fiddle with context size
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Prompt: blueSkyPrompt,
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,18 @@ var (
|
||||
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
|
||||
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
||||
func TestModelsPerf(t *testing.T) {
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
doModelPerfTest(t, ollamaEngineChatModels)
|
||||
} else {
|
||||
doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLibraryModelsPerf(t *testing.T) {
|
||||
doModelPerfTest(t, libraryChatModels)
|
||||
}
|
||||
|
||||
func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
@@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) {
|
||||
}
|
||||
longPrompt := "summarize the following: " + string(data)
|
||||
|
||||
var chatModels []string
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
chatModels = ollamaEngineChatModels
|
||||
} else {
|
||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||
}
|
||||
targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE")
|
||||
|
||||
for _, model := range chatModels {
|
||||
if !strings.Contains(model, ":") {
|
||||
model = model + ":latest"
|
||||
}
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
@@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) {
|
||||
}
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
|
||||
if targetArch != "" && arch != targetArch {
|
||||
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||
}
|
||||
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
@@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) {
|
||||
prompt string
|
||||
anyResp []string
|
||||
}{
|
||||
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
|
||||
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
|
||||
{blueSkyPrompt, blueSkyExpected},
|
||||
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}},
|
||||
}
|
||||
var gpuPercent int
|
||||
for _, tc := range testCases {
|
||||
@@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Round the logged prompt count for comparisons across versions/configurations which can vary slightly
|
||||
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
|
||||
"MODEL",
|
||||
"CONTEXT",
|
||||
"GPU PERCENT",
|
||||
"PROMPT COUNT",
|
||||
"APPROX PROMPT COUNT",
|
||||
"LOAD TIME",
|
||||
"PROMPT EVAL TPS",
|
||||
"EVAL TPS",
|
||||
@@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) {
|
||||
model,
|
||||
numCtx,
|
||||
gpuPercent,
|
||||
resp.PromptEvalCount,
|
||||
(resp.PromptEvalCount/10)*10,
|
||||
float64(resp.LoadDuration)/1000000000.0,
|
||||
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
|
||||
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
|
||||
|
||||
@@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) {
|
||||
stream := true
|
||||
genReq := api.GenerateRequest{
|
||||
Model: newName,
|
||||
Prompt: "why is the sky blue?",
|
||||
Prompt: blueSkyPrompt,
|
||||
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
@@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) {
|
||||
|
||||
// Some smaller quantizations can cause models to have poor quality
|
||||
// or get stuck in repetition loops, so we stop as soon as we have any matches
|
||||
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
atLeastOne := false
|
||||
var buf bytes.Buffer
|
||||
genfn := func(response api.GenerateResponse) error {
|
||||
buf.Write([]byte(response.Response))
|
||||
fullResp := strings.ToLower(buf.String())
|
||||
for _, resp := range anyResp {
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(fullResp, resp) {
|
||||
atLeastOne = true
|
||||
t.Log(fullResp)
|
||||
|
||||
@@ -256,13 +256,29 @@ var (
|
||||
"snowflake-arctic-embed",
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
|
||||
blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
|
||||
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}
|
||||
|
||||
rainbowPrompt = "how do rainbows form? Be brief but factual in your reply"
|
||||
rainbowFollowups = []string{
|
||||
"Explain the physics involved in them. Be breif in your reply",
|
||||
"Explain the chemistry involved in them. Be breif in your reply",
|
||||
"Explain the quantum mechanics involved in them. Be breif in your reply",
|
||||
"What are common myths related to them? Be brief in your reply",
|
||||
"What are common fairytales related to them? Be brief in your reply",
|
||||
"Can they form if there is no rain? Be breif in your reply",
|
||||
"Can they form if there are no clouds? Be breif in your reply",
|
||||
"Do they happen on other planets? Be brief in your reply",
|
||||
}
|
||||
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"}
|
||||
)
|
||||
|
||||
func init() {
|
||||
lifecycle.InitLogging()
|
||||
custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL")
|
||||
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
|
||||
if custom != "" {
|
||||
slog.Info("setting smol test model to " + custom)
|
||||
slog.Info("setting default test model to " + custom)
|
||||
smol = custom
|
||||
}
|
||||
}
|
||||
@@ -561,7 +577,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
|
||||
Prompt: "how do rainbows form? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}, {
|
||||
@@ -577,11 +593,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
},
|
||||
},
|
||||
[][]string{
|
||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
|
||||
{"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"},
|
||||
{"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"},
|
||||
{"fourth", "july", "declaration", "independence"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
|
||||
}
|
||||
nChunks := C.mtmd_input_chunks_size(ic)
|
||||
numEmbed := llamaContext.Model().NEmbd()
|
||||
lastChunkSize := 0
|
||||
embed := make([][]float32, 0)
|
||||
for i := range int(nChunks) {
|
||||
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
|
||||
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
|
||||
lastChunkSize = numTokens
|
||||
slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
|
||||
|
||||
// Encode the chunk
|
||||
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
|
||||
return nil, errors.New("unable to encode mtmd image chunk")
|
||||
}
|
||||
}
|
||||
|
||||
// Get the embeddings
|
||||
embed := make([][]float32, lastChunkSize)
|
||||
embd := C.mtmd_get_output_embd(c.c)
|
||||
if nil == embd {
|
||||
return nil, errors.New("failed to get image embedding")
|
||||
}
|
||||
// Get the embeddings for this chunk
|
||||
chunkEmbed := make([][]float32, numTokens)
|
||||
chunkEmbd := C.mtmd_get_output_embd(c.c)
|
||||
if nil == chunkEmbd {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extend the embedding array for each token
|
||||
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
|
||||
rows := make([]float32, len(s))
|
||||
copy(rows, s)
|
||||
for i := range lastChunkSize {
|
||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||
// Extend the embedding array for each token
|
||||
s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
|
||||
rows := make([]float32, len(s))
|
||||
copy(rows, s)
|
||||
for i := range numTokens {
|
||||
chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||
}
|
||||
embed = append(embed, chunkEmbed...)
|
||||
}
|
||||
|
||||
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
|
||||
return embed, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/parser"
|
||||
)
|
||||
|
||||
type filteredEnv []string
|
||||
@@ -1349,9 +1348,7 @@ type CompletionRequest struct {
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
ParserType parser.TokenParserType
|
||||
PrefillString string
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1378,15 +1375,13 @@ func (d DoneReason) String() string {
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
ToolCalls []api.ToolCall `json:"tool_calls"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
Content string `json:"content"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
@@ -1504,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||
}
|
||||
switch {
|
||||
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
|
||||
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
@@ -1518,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Done {
|
||||
fn(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
|
||||
fn(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
const LevelTrace slog.Level = -8
|
||||
@@ -29,10 +31,18 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
|
||||
}))
|
||||
}
|
||||
|
||||
type key string
|
||||
|
||||
func Trace(msg string, args ...any) {
|
||||
slog.Log(context.TODO(), LevelTrace, msg, args...)
|
||||
TraceContext(context.WithValue(context.TODO(), key("skip"), 1), msg, args...)
|
||||
}
|
||||
|
||||
func TraceContext(ctx context.Context, msg string, args ...any) {
|
||||
slog.Log(ctx, LevelTrace, msg, args...)
|
||||
if logger := slog.Default(); logger.Enabled(ctx, LevelTrace) {
|
||||
skip, _ := ctx.Value(key("skip")).(int)
|
||||
pc, _, _, _ := runtime.Caller(1 + skip)
|
||||
record := slog.NewRecord(time.Now(), LevelTrace, msg, pc)
|
||||
record.Add(args...)
|
||||
logger.Handler().Handle(ctx, record)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -416,6 +416,7 @@ type Tensor interface {
|
||||
AddID(ctx Context, t2, ids Tensor) Tensor
|
||||
|
||||
Softmax(ctx Context) Tensor
|
||||
L2Norm(ctx Context, eps float32) Tensor
|
||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
@@ -429,12 +430,13 @@ type Tensor interface {
|
||||
Sin(ctx Context) Tensor
|
||||
Cos(ctx Context) Tensor
|
||||
Tanh(ctx Context) Tensor
|
||||
GELU(ctx Context) Tensor
|
||||
QuickGELU(ctx Context) Tensor
|
||||
SILU(ctx Context) Tensor
|
||||
RELU(ctx Context) Tensor
|
||||
GELU(ctx Context, up ...Tensor) Tensor
|
||||
SILU(ctx Context, up ...Tensor) Tensor
|
||||
RELU(ctx Context, up ...Tensor) Tensor
|
||||
Sigmoid(ctx Context) Tensor
|
||||
SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
|
||||
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
|
||||
|
||||
Reshape(ctx Context, shape ...int) Tensor
|
||||
View(ctx Context, offset int, shape ...int) Tensor
|
||||
|
||||
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
|
||||
if w != nil {
|
||||
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
if len(t2) > 0 {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
|
||||
}
|
||||
}
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
|
||||
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
if len(t2) > 0 {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
|
||||
func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
|
||||
if len(t2) > 0 {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
|
||||
}
|
||||
}
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
|
||||
func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
|
||||
|
||||
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
|
||||
}
|
||||
|
||||
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
ctx.Forward(query)
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
ctx.Forward(key, value)
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
42
ml/nn/pooling/pooling.go
Normal file
42
ml/nn/pooling/pooling.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package pooling
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeNone Type = iota
|
||||
TypeMean
|
||||
TypeCLS
|
||||
TypeLast
|
||||
)
|
||||
|
||||
func (t Type) String() string {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
return "Mean"
|
||||
case TypeCLS:
|
||||
return "CLS"
|
||||
case TypeLast:
|
||||
return "Last"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
switch t {
|
||||
case TypeMean:
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
case TypeCLS:
|
||||
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
||||
case TypeLast:
|
||||
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
|
||||
return hiddenStates
|
||||
default:
|
||||
panic("unknown pooling type")
|
||||
}
|
||||
}
|
||||
79
ml/nn/pooling/pooling_test.go
Normal file
79
ml/nn/pooling/pooling_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package pooling_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/discover"
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
)
|
||||
|
||||
func setup(tb testing.TB, n int) ml.Backend {
|
||||
tb.Helper()
|
||||
|
||||
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := fsggml.WriteGGUF(f, fsggml.KV{
|
||||
"general.architecture": "test",
|
||||
"test.block_count": uint32(1),
|
||||
}, []*fsggml.Tensor{
|
||||
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
|
||||
}); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
var gpuLayers ml.GPULayersList
|
||||
if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
|
||||
gpuLayers = append(gpuLayers, ml.GPULayers{
|
||||
ID: gpus[0].ID,
|
||||
Layers: slices.Collect(func(yield func(int) bool) {
|
||||
for i := range n {
|
||||
if !yield(i) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
})
|
||||
}
|
||||
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func TestForward(t *testing.T) {
|
||||
cases := map[pooling.Type][]float32{
|
||||
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
|
||||
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
|
||||
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
|
||||
}
|
||||
for typ, want := range cases {
|
||||
t.Run(typ.String(), func(t *testing.T) {
|
||||
b := setup(t, 99)
|
||||
defer b.Close()
|
||||
|
||||
ctx := b.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
|
||||
tt = typ.Forward(ctx, tt)
|
||||
|
||||
ctx.Forward(tt).Compute(tt)
|
||||
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
|
||||
t.Error(diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/dlclark/regexp2"
|
||||
@@ -13,16 +14,28 @@ import (
|
||||
)
|
||||
|
||||
type BytePairEncoding struct {
|
||||
pre *regexp2.Regexp
|
||||
vocab *Vocabulary
|
||||
vocab *Vocabulary
|
||||
regexps []*regexp2.Regexp
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*BytePairEncoding)(nil)
|
||||
|
||||
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
||||
func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
|
||||
if len(pretokenizers) == 0 {
|
||||
// set default byte-level pretokenizer if none provided, e.g.
|
||||
// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
|
||||
pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
|
||||
}
|
||||
|
||||
return BytePairEncoding{
|
||||
pre: regexp2.MustCompile(pre, regexp2.None),
|
||||
vocab: vocab,
|
||||
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
|
||||
for _, p := range pretokenizers {
|
||||
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool {
|
||||
}
|
||||
|
||||
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
|
||||
if !yield(m.String()) {
|
||||
break
|
||||
parts := []string{s}
|
||||
for _, re := range bpe.regexps {
|
||||
parts = slices.Collect(func(yield func(string) bool) {
|
||||
for _, part := range parts {
|
||||
r := []rune(part)
|
||||
var offset int
|
||||
for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
|
||||
if offset-m.Index != 0 {
|
||||
if !yield(string(r[:m.Index])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if !yield(m.String()) {
|
||||
return
|
||||
}
|
||||
|
||||
offset = m.Index + m.Length
|
||||
}
|
||||
|
||||
if offset < len(r) {
|
||||
if !yield(string(r[offset:])) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return slices.Values(parts)
|
||||
}
|
||||
|
||||
// fragment is a string fragment and their corresponding token IDs
|
||||
|
||||
@@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding {
|
||||
}
|
||||
|
||||
return NewBytePairEncoding(
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
&Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Merges: merges,
|
||||
},
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
patterns,
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
patterns: []string{
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
|
||||
},
|
||||
{
|
||||
name: "individual digits",
|
||||
patterns: []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
},
|
||||
want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenizer := NewBytePairEncoding(nil, tt.patterns...)
|
||||
if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
|
||||
t.Errorf("no match (-theirs +ours):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,10 +54,9 @@ type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs ml.Tensor
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
@@ -66,7 +65,8 @@ type Batch struct {
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs []int32
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
}
|
||||
|
||||
122
model/model.go
122
model/model.go
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"math"
|
||||
"log/slog"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -21,10 +21,15 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
var (
|
||||
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||
ErrUnsupportedModel = errors.New("model not supported")
|
||||
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||
)
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
@@ -103,23 +108,12 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
arch := b.Config().Architecture()
|
||||
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
}
|
||||
|
||||
m, err := f(b.Config())
|
||||
m, err := modelForArch(b.Config())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base := Base{b: b, config: m.Config()}
|
||||
|
||||
v := reflect.ValueOf(m)
|
||||
v.Elem().Set(populateFields(base, v.Elem()))
|
||||
return m, nil
|
||||
@@ -131,30 +125,38 @@ func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
meta, err := fsggml.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getTextProcessor(meta.KV())
|
||||
}
|
||||
|
||||
func getTextProcessor(kv fsggml.KV) (TextProcessor, error) {
|
||||
arch := kv.Architecture()
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
}
|
||||
m, err := f(kv)
|
||||
m, err := modelForArch(meta.KV())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
||||
return nil, ErrUnsupportedTokenizer
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
func modelForArch(c fs.Config) (Model, error) {
|
||||
arch := c.Architecture()
|
||||
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, ErrUnsupportedModel
|
||||
}
|
||||
|
||||
return f(c)
|
||||
}
|
||||
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
|
||||
@@ -170,35 +172,44 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
// make a copy
|
||||
tagsCopy := tags
|
||||
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
|
||||
tagsCopy = append(tagsCopy, ParseTags(tag))
|
||||
tagsCopy = append(tagsCopy, parseTag(tag))
|
||||
}
|
||||
|
||||
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
|
||||
vv.Set(reflect.ValueOf(base))
|
||||
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
|
||||
var fn func([]Tag) [][]string
|
||||
fn = func(tags []Tag) (names [][]string) {
|
||||
var fn func([]Tag, string, string) [][]string
|
||||
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
|
||||
if len(tags) > 0 {
|
||||
localNames := []string{tags[0].Name}
|
||||
localNames = append(localNames, tags[0].Alternate...)
|
||||
|
||||
for _, localName := range localNames {
|
||||
fullName := []string{localName}
|
||||
nested := fn(tags[1:])
|
||||
if len(nested) > 0 {
|
||||
for _, rest := range nested {
|
||||
names = append(names, append(fullName, rest...))
|
||||
var names []string
|
||||
if tags[0].name != "" {
|
||||
for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
|
||||
names = append(names, prefix+n+suffix)
|
||||
}
|
||||
}
|
||||
childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
|
||||
if len(names) == 0 {
|
||||
// current tag has no name, use child names only
|
||||
fullNames = append(fullNames, childNames...)
|
||||
} else if len(childNames) == 0 {
|
||||
// current tag has names but no children, create branches for each name
|
||||
for _, name := range names {
|
||||
fullNames = append(fullNames, []string{name})
|
||||
}
|
||||
} else {
|
||||
// merge each name with each child
|
||||
for _, name := range names {
|
||||
for _, childName := range childNames {
|
||||
fullNames = append(fullNames, append([]string{name}, childName...))
|
||||
}
|
||||
} else {
|
||||
names = append(names, fullName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return names
|
||||
return fullNames
|
||||
}
|
||||
|
||||
names := fn(tagsCopy)
|
||||
names := fn(tagsCopy, "", "")
|
||||
for _, name := range names {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
logutil.Trace("found tensor", "", tensor)
|
||||
@@ -212,9 +223,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
for i := range vv.Len() {
|
||||
vvv := vv.Index(i)
|
||||
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
|
||||
setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
|
||||
setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
|
||||
} else {
|
||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
|
||||
vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -242,7 +253,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||
vv = vv.Elem()
|
||||
}
|
||||
|
||||
vv = vv.Elem()
|
||||
vv = reflect.Indirect(vv)
|
||||
if v.IsNil() {
|
||||
vv = reflect.New(v.Type().Elem()).Elem()
|
||||
}
|
||||
@@ -253,18 +264,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
Name string
|
||||
Alternate []string
|
||||
name,
|
||||
// prefix and suffix are applied to child tags
|
||||
prefix,
|
||||
suffix string
|
||||
alternatives []string
|
||||
}
|
||||
|
||||
func ParseTags(s string) (tag Tag) {
|
||||
func parseTag(s string) (tag Tag) {
|
||||
parts := strings.Split(s, ",")
|
||||
if len(parts) > 0 {
|
||||
tag.Name = parts[0]
|
||||
tag.name = parts[0]
|
||||
|
||||
for _, part := range parts[1:] {
|
||||
if value, ok := strings.CutPrefix(part, "alt:"); ok {
|
||||
tag.Alternate = append(tag.Alternate, value)
|
||||
if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
|
||||
// elevate alternative to primary if no primary given
|
||||
tag.name = value
|
||||
slog.Warn("gguf tag has alt: but no primary name", "tag", s)
|
||||
} else if ok {
|
||||
tag.alternatives = append(tag.alternatives, value)
|
||||
}
|
||||
if value, ok := strings.CutPrefix(part, "pre:"); ok {
|
||||
tag.prefix = value
|
||||
}
|
||||
if value, ok := strings.CutPrefix(part, "suf:"); ok {
|
||||
tag.suffix = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
func TestParseTags(t *testing.T) {
|
||||
@@ -23,14 +22,14 @@ func TestParseTags(t *testing.T) {
|
||||
{
|
||||
value: "output",
|
||||
want: Tag{
|
||||
Name: "output",
|
||||
name: "output",
|
||||
},
|
||||
},
|
||||
{
|
||||
value: "output,alt:token_embd",
|
||||
want: Tag{
|
||||
Name: "output",
|
||||
Alternate: []string{
|
||||
name: "output",
|
||||
alternatives: []string{
|
||||
"token_embd",
|
||||
},
|
||||
},
|
||||
@@ -39,8 +38,8 @@ func TestParseTags(t *testing.T) {
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.value, func(t *testing.T) {
|
||||
got := ParseTags(tt.value)
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
got := parseTag(tt.value)
|
||||
if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" {
|
||||
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
@@ -126,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
Input *nn.Embedding `gguf:"input"`
|
||||
Output *nn.Linear `gguf:"output,alt:input"`
|
||||
Nested *nested `gguf:"nested"`
|
||||
Tensor ml.Tensor `gguf:"leaf,alt:tensor"`
|
||||
}
|
||||
|
||||
var m fakeModel
|
||||
@@ -134,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
names: []string{
|
||||
"input.weight",
|
||||
"nested.b.weight",
|
||||
"leaf",
|
||||
},
|
||||
}}, v.Elem()))
|
||||
|
||||
@@ -143,44 +144,115 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
Nested: &nested{
|
||||
Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}},
|
||||
},
|
||||
Tensor: &fakeTensor{Name: "leaf"},
|
||||
}, m); diff != "" {
|
||||
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTextProcessor(t *testing.T) {
|
||||
tp, err := getTextProcessor(fsggml.KV{})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
func TestPopulateFieldsPrefixSuffixName(t *testing.T) {
|
||||
type fakeBlock struct {
|
||||
A *nn.Linear `gguf:"a"`
|
||||
B *nn.Linear `gguf:",pre:b_"`
|
||||
C *nn.Linear `gguf:",suf:_c"`
|
||||
XY *nn.Linear `gguf:",pre:x_,suf:_y"`
|
||||
}
|
||||
|
||||
models["dummy"] = func(fs.Config) (Model, error) {
|
||||
return notTextProcessorModel{}, nil
|
||||
type fakeModel struct {
|
||||
Blocks []fakeBlock `gguf:"blk"`
|
||||
}
|
||||
tp, err = getTextProcessor(fsggml.KV{"general.architecture": "dummy"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
|
||||
m := fakeModel{
|
||||
Blocks: make([]fakeBlock, 2),
|
||||
}
|
||||
v := reflect.ValueOf(&m)
|
||||
v.Elem().Set(populateFields(Base{b: &fakeBackend{
|
||||
names: []string{
|
||||
"blk.0.a.weight",
|
||||
"blk.0.b_weight",
|
||||
"blk.0.b_bias",
|
||||
"blk.0.weight_c",
|
||||
"blk.0.x_weight_y",
|
||||
"blk.1.a.weight",
|
||||
"blk.1.b_weight",
|
||||
"blk.1.b_bias",
|
||||
"blk.1.weight_c",
|
||||
"blk.1.x_weight_y",
|
||||
},
|
||||
}}, v.Elem()))
|
||||
|
||||
if diff := cmp.Diff(fakeModel{
|
||||
Blocks: []fakeBlock{
|
||||
{
|
||||
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}},
|
||||
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}},
|
||||
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}},
|
||||
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}},
|
||||
},
|
||||
{
|
||||
A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}},
|
||||
B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}},
|
||||
C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}},
|
||||
XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}},
|
||||
},
|
||||
},
|
||||
}, m); diff != "" {
|
||||
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
type notTextProcessorModel struct{}
|
||||
func TestModelForArch(t *testing.T) {
|
||||
type fakeModel struct {
|
||||
Model
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
type fakeEmbeddingModel struct {
|
||||
Model
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Backend() ml.Backend {
|
||||
panic("unimplemented")
|
||||
}
|
||||
models["model"] = func(c fs.Config) (Model, error) { return fakeModel{}, nil }
|
||||
models["model_embed"] = func(c fs.Config) (Model, error) { return fakeEmbeddingModel{}, nil }
|
||||
|
||||
func (notTextProcessorModel) Config() config {
|
||||
panic("unimplemented")
|
||||
cases := []struct {
|
||||
name string
|
||||
config fs.Config
|
||||
want any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "model",
|
||||
config: fsggml.KV{
|
||||
"general.architecture": "model",
|
||||
},
|
||||
want: fakeModel{},
|
||||
},
|
||||
{
|
||||
name: "embedding",
|
||||
config: fsggml.KV{
|
||||
"general.architecture": "model",
|
||||
"model.pooling_type": uint32(1),
|
||||
},
|
||||
want: fakeEmbeddingModel{},
|
||||
},
|
||||
{
|
||||
name: "unsupported",
|
||||
config: fsggml.KV{
|
||||
"general.architecture": "unsupported",
|
||||
},
|
||||
err: ErrUnsupportedModel,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := modelForArch(tt.config)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
t.Errorf("modelForArch() returned unexpected values (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
181
model/models/bert/embed.go
Normal file
181
model/models/bert/embed.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package bert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
|
||||
|
||||
Layers []EncoderLayer `gguf:"blk"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
|
||||
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
|
||||
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
if m.normalize {
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
}
|
||||
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
type EncoderLayer struct {
|
||||
*Attention
|
||||
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
|
||||
|
||||
*MLP
|
||||
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
|
||||
}
|
||||
|
||||
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
// Attention
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
// MLP
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
|
||||
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
|
||||
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
|
||||
query := a.Query.Forward(ctx, hiddenStates)
|
||||
if a.QueryNorm != nil {
|
||||
query = a.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
}
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
|
||||
key := a.Key.Forward(ctx, hiddenStates)
|
||||
if a.KeyNorm != nil {
|
||||
key = a.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
}
|
||||
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||
|
||||
value := a.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return a.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
hiddenSize,
|
||||
numHeads,
|
||||
numKVHeads,
|
||||
keyLength,
|
||||
valueLength int
|
||||
poolingType pooling.Type
|
||||
eps float32
|
||||
normalize bool
|
||||
}
|
||||
|
||||
func (o Options) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
var processor model.TextProcessor
|
||||
switch c.String("tokenizer.ggml.model", "bert") {
|
||||
case "bert":
|
||||
processor = model.NewWordPiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||
)),
|
||||
},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||
EOS: []int32{
|
||||
int32(cmp.Or(
|
||||
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||
//nolint:misspell
|
||||
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
|
||||
// support it for compatibility.
|
||||
c.Uint("tokenizer.ggml.seperator_token_id"),
|
||||
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||
)),
|
||||
},
|
||||
},
|
||||
)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
return &Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_epsilon"),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||
normalize: c.Bool("normalize_embeddings", true),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("bert", New)
|
||||
model.Register("bert_embed", New)
|
||||
}
|
||||
324
model/models/deepseek2/model.go
Normal file
324
model/models/deepseek2/model.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package deepseek2
|
||||
|
||||
// uses deepseek 2 architecture but written based on deepseek 3 model
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
numExpertsUsed int
|
||||
numExperts int
|
||||
normTopKProb bool
|
||||
routedScalingFactor float32
|
||||
|
||||
kvLoraRank,
|
||||
qkNopeHeadDim,
|
||||
qkRopeHeadDim,
|
||||
kqNopeHeadDim,
|
||||
qkHeadDim int
|
||||
qLoraRank int
|
||||
vHeadDim int
|
||||
|
||||
hiddenSize,
|
||||
numHeads,
|
||||
numKVHeads,
|
||||
keyLength,
|
||||
valueLength,
|
||||
originalContextLength int
|
||||
|
||||
eps,
|
||||
ropeBase,
|
||||
ropeScale float32
|
||||
kqScale float64
|
||||
}
|
||||
|
||||
func (o Options) RoPEOptions() []func(*rope.Options) {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
return []func(*rope.Options){
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
}
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Q *nn.Linear `gguf:"attn_q"`
|
||||
|
||||
QA *nn.Linear `gguf:"attn_q_a"`
|
||||
QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"`
|
||||
QB *nn.Linear `gguf:"attn_q_b"`
|
||||
|
||||
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
|
||||
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
|
||||
KVB *nn.Linear `gguf:"attn_kv_b"`
|
||||
|
||||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
seqLength := hiddenStates.Dim(1)
|
||||
|
||||
var query ml.Tensor
|
||||
if opts.qLoraRank == 0 { // nil {
|
||||
query = attn.Q.Forward(ctx, hiddenStates)
|
||||
} else {
|
||||
query = attn.QA.Forward(ctx, hiddenStates)
|
||||
query = attn.QANorm.Forward(ctx, query, opts.eps)
|
||||
query = attn.QB.Forward(ctx, query)
|
||||
}
|
||||
|
||||
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
|
||||
|
||||
qPass := query.View(ctx, 0,
|
||||
opts.qkNopeHeadDim, query.Stride(1),
|
||||
query.Dim(1), query.Stride(2),
|
||||
query.Dim(2))
|
||||
|
||||
qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0),
|
||||
opts.qkRopeHeadDim, query.Stride(1),
|
||||
query.Dim(1), query.Stride(2),
|
||||
query.Dim(2))
|
||||
|
||||
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
|
||||
|
||||
kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1))
|
||||
kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0),
|
||||
opts.qkRopeHeadDim, compressedKV.Stride(1),
|
||||
1, compressedKV.Stride(1),
|
||||
compressedKV.Dim(1))
|
||||
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2))
|
||||
value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0),
|
||||
opts.vHeadDim, kv.Stride(1),
|
||||
kv.Dim(1), kv.Stride(2),
|
||||
kv.Dim(2)).Contiguous(ctx)
|
||||
|
||||
qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
|
||||
kRot = kRot.Repeat(ctx, 1, qPass.Dim(1))
|
||||
|
||||
query = qRot.Concat(ctx, qPass, 0)
|
||||
key := kRot.Concat(ctx, kPass, 0)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type MLP interface {
|
||||
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
|
||||
}
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
SharedExpert *dense `gguf:",suf:_shexp"`
|
||||
ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"`
|
||||
}
|
||||
|
||||
func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||
hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||
hiddenStates = hiddenStates.SILU(ctx, upStates)
|
||||
|
||||
experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices)
|
||||
experts = experts.Mul(ctx, topKWeights)
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
return nextStates
|
||||
}
|
||||
|
||||
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
|
||||
scores = scores.Add(ctx, moe.ExpProbsBias)
|
||||
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
|
||||
return topKIndices
|
||||
}
|
||||
|
||||
func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
residuals := hiddenStates
|
||||
|
||||
routerLogits := moe.Router.Forward(ctx, hiddenStates)
|
||||
scores := routerLogits.Sigmoid(ctx)
|
||||
topKIndices := moe.topKIndices(ctx, scores, opts)
|
||||
topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices)
|
||||
|
||||
if opts.normTopKProb {
|
||||
topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx))
|
||||
topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
|
||||
}
|
||||
|
||||
topKWeights = topKWeights.Scale(ctx, float64(opts.routedScalingFactor))
|
||||
hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts)
|
||||
sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts)
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, sharedExpertResult)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type dense struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Attention *Attention
|
||||
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP MLP
|
||||
}
|
||||
|
||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
residual = hiddenStates
|
||||
|
||||
hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
return hiddenStates
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
layers := make([]Layer, c.Uint("block_count"))
|
||||
|
||||
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
|
||||
for i := range layers {
|
||||
if i < firstDenseLayerIndex {
|
||||
layers[i].MLP = &dense{}
|
||||
} else {
|
||||
layers[i].MLP = &sparse{}
|
||||
}
|
||||
}
|
||||
|
||||
mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor"))))
|
||||
kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length")))
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
// Split regex into multiple parts (according to DeepSeek3's regex)
|
||||
"\\p{N}{1,3}",
|
||||
`[一-龥-ゟ゠-ヿ]+`,
|
||||
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
|
||||
),
|
||||
Layers: layers,
|
||||
Options: &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("expert_weights_norm", true),
|
||||
|
||||
qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal,
|
||||
kvLoraRank: int(c.Uint("attention.kv_lora_rank")),
|
||||
qkHeadDim: int(c.Uint("attention.key_length")),
|
||||
vHeadDim: int(c.Uint("attention.value_length")),
|
||||
qkRopeHeadDim: int(c.Uint("rope.dimension_count")),
|
||||
qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
|
||||
kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")),
|
||||
|
||||
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
|
||||
kqScale: kqScale,
|
||||
},
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("deepseek2", New)
|
||||
}
|
||||
@@ -24,7 +24,7 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
model.SentencePiece
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -40,7 +40,7 @@ const (
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
attnValLen: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base", 10000.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping"),
|
||||
},
|
||||
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -138,7 +138,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||
|
||||
@@ -1,49 +1,38 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
PoolingType uint32
|
||||
poolingType pooling.Type
|
||||
|
||||
Dense [2]*nn.Linear `gguf:"dense"`
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
batch.Outputs = batch.Positions // return all positions
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
switch m.PoolingType {
|
||||
case 0: // None
|
||||
case 1: // Mean
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
default:
|
||||
return nil, errors.New("unsupported pooling type")
|
||||
}
|
||||
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
PoolingType: c.Uint("pooling_type", 0),
|
||||
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
model.SentencePiece
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
||||
@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
ropeScale: 1,
|
||||
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||
// (8 instead of 1)
|
||||
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -84,7 +87,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -95,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
||||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -123,7 +126,7 @@ type TextMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -161,7 +164,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
@@ -194,7 +196,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
model.SentencePiece
|
||||
|
||||
*TextModel
|
||||
}
|
||||
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
||||
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
||||
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
@@ -95,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
|
||||
}
|
||||
|
||||
active = d.PerLayerInputGate.Forward(ctx, active)
|
||||
active = active.GELU(ctx)
|
||||
active = active.Mul(ctx, perLayerInput)
|
||||
active = active.GELU(ctx, perLayerInput)
|
||||
|
||||
active = d.PerLayerProjection.Forward(ctx, active)
|
||||
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
|
||||
@@ -257,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
var key, value ml.Tensor
|
||||
if !sharedKV {
|
||||
key = attn.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
value = attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
|
||||
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
|
||||
}
|
||||
|
||||
hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
|
||||
hiddenStates = hiddenStates.GELU(ctx, upStates)
|
||||
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
|
||||
return hiddenStates
|
||||
}
|
||||
@@ -350,7 +349,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: c.Float("rope.freq_base", 1_000_000),
|
||||
ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
|
||||
ropeScale: c.Float("rope.freq_scale", 1.0),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
activationSparsityScale: c.Floats("activation_sparsity_scale"),
|
||||
|
||||
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if i == len(m.TransformerBlocks)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
||||
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
|
||||
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
|
||||
}
|
||||
|
||||
hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
|
||||
hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
|
||||
|
||||
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
@@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
m := Transformer{
|
||||
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer",
|
||||
strings.Join([]string{
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||
`\p{N}{1,3}`,
|
||||
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
|
||||
`\s*[\r\n]+`,
|
||||
`\s+(?!\S)`,
|
||||
`\s+`,
|
||||
}, "|"),
|
||||
),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
strings.Join([]string{
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
|
||||
`\p{N}{1,3}`,
|
||||
` ?[^\s\p{L}\p{N}]+[\r\n/]*`,
|
||||
`\s*[\r\n]+`,
|
||||
`\s+(?!\S)`,
|
||||
`\s+`,
|
||||
}, "|"),
|
||||
),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
|
||||
@@ -2,7 +2,6 @@ package llama
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
@@ -23,51 +22,80 @@ type Options struct {
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*Options
|
||||
Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
// This model currently only supports the gpt2 tokenizer
|
||||
if c.String("tokenizer.ggml.model") == "llama" {
|
||||
return nil, fmt.Errorf("unsupported tokenizer: llama")
|
||||
if c.Uint("expert_count") > 0 {
|
||||
// TODO: support mixtures of experts
|
||||
return nil, model.ErrUnsupportedModel
|
||||
}
|
||||
// Best effort detection of library/deepseek-coder model(s) which are incompatible
|
||||
if c.String("general.name") == "deepseek-ai" {
|
||||
return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
|
||||
}
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
|
||||
var processor model.TextProcessor
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: &Options{
|
||||
}
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
var pretokenizers []string
|
||||
switch c.String("tokenizer.ggml.pre") {
|
||||
case "default":
|
||||
// no-op use the default bpe pretokenizer
|
||||
case "qwen2":
|
||||
pretokenizers = []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
case "refact":
|
||||
pretokenizers = []string{
|
||||
`\p{N}`,
|
||||
`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`,
|
||||
}
|
||||
case "tekken":
|
||||
pretokenizers = []string{
|
||||
"[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
default:
|
||||
// use a llama-style pretokenizer
|
||||
pretokenizers = []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
case "llama":
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
default:
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeBase: c.Float("rope.freq_base", 1e5),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -98,8 +126,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
@@ -108,7 +136,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -118,7 +146,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -160,10 +188,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
|
||||
@@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer",
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
@@ -176,9 +175,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -33,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
if useRope {
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
}
|
||||
|
||||
if opts.useQKNorm {
|
||||
@@ -58,14 +58,14 @@ type TextMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type TextExperts struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
||||
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
|
||||
hiddenStates = hiddenStates.Mul(ctx, scores)
|
||||
|
||||
upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
upStates := e.Up.Forward(ctx, hiddenStates, experts)
|
||||
gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
|
||||
downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
|
||||
|
||||
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
@@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
|
||||
return nextStates
|
||||
}
|
||||
|
||||
// TextSharedExpert is TextMLP with different tensor names
|
||||
type TextSharedExpert struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
|
||||
Up *nn.Linear `gguf:"ffn_up_shexp"`
|
||||
Down *nn.Linear `gguf:"ffn_down_shexp"`
|
||||
}
|
||||
|
||||
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type TextMOE struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Experts *TextExperts
|
||||
SharedExpert *TextSharedExpert
|
||||
SharedExpert *TextMLP `gguf:",suf:_shexp"`
|
||||
}
|
||||
|
||||
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
@@ -196,7 +184,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
|
||||
noRopeInterval: int(c.Uint("no_rope_interval", 4)),
|
||||
@@ -248,5 +236,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil)
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -40,11 +40,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -55,7 +55,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
@@ -65,7 +65,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ type VisionMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ const (
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -26,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -45,7 +45,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
return key, nil
|
||||
@@ -58,7 +58,7 @@ type TextMLP struct {
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/model/models/bert"
|
||||
_ "github.com/ollama/ollama/model/models/deepseek2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||
|
||||
@@ -43,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
||||
value := attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
@@ -59,7 +59,7 @@ type MLP struct {
|
||||
}
|
||||
|
||||
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||
@@ -124,7 +124,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
@@ -139,7 +139,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
Layers: make([]DecoderLayer, c.Uint("block_count")),
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
@@ -160,7 +160,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
TextModel: NewTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -38,7 +38,7 @@ func NewTextModel(c fs.Config) *TextModel {
|
||||
originalContextLength: int(c.Uint("context_length", 128000)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
},
|
||||
}
|
||||
|
||||
@@ -60,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -78,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
@@ -90,7 +90,7 @@ type MLP struct {
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Apply SwiGLU activation gating
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
// Project back to hidden dimension
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
@@ -100,8 +100,7 @@ type VisionMLP struct {
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Using activation as specified in config (likely GELU or SiLU/Swish)
|
||||
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
|
||||
upOutput := mlp.Up.Forward(ctx, hiddenStates)
|
||||
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
|
||||
hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
73
model/models/qwen3/embed.go
Normal file
73
model/models/qwen3/embed.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package qwen3
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*Model
|
||||
poolingType pooling.Type
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates, err := m.forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
|
||||
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbed(c fs.Config) (model.Model, error) {
|
||||
layers := make([]Layer, c.Uint("block_count"))
|
||||
for i := range layers {
|
||||
layers[i].MLP = &dense{}
|
||||
}
|
||||
m := embedModel{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Model: &Model{
|
||||
Layers: layers,
|
||||
Options: &Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
keyLength: int(c.Uint("attention.key_length")),
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
},
|
||||
},
|
||||
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
return &m, nil
|
||||
}
|
||||
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
@@ -52,8 +52,8 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
@@ -65,10 +65,10 @@ type MLP interface {
|
||||
}
|
||||
|
||||
type sparse struct {
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate_exps"`
|
||||
Up *nn.Linear `gguf:"ffn_up_exps"`
|
||||
Down *nn.Linear `gguf:"ffn_down_exps"`
|
||||
Router *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
}
|
||||
|
||||
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
|
||||
|
||||
upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
|
||||
|
||||
hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
hiddenStates = hiddenStates.SILU(ctx)
|
||||
hiddenStates = hiddenStates.Mul(ctx, upStates)
|
||||
|
||||
experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
|
||||
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
@@ -111,7 +107,8 @@ type dense struct {
|
||||
}
|
||||
|
||||
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
|
||||
SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
@@ -154,29 +151,39 @@ type Model struct {
|
||||
*Options
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenStates, err := m.forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
if m.Cache != nil {
|
||||
m.Cache.SetLayer(i)
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
return m.OutputNorm.Forward(ctx, hiddenStates, m.eps), nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
var _ model.Model = (*Model)(nil)
|
||||
@@ -193,7 +200,6 @@ func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
@@ -206,6 +212,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
||||
),
|
||||
Layers: layers,
|
||||
Options: &Options{
|
||||
@@ -216,7 +223,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
valueLength: int(c.Uint("attention.value_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
numExperts: int(c.Uint("expert_count")),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count")),
|
||||
normTopKProb: c.Bool("norm_top_k_prob", true),
|
||||
@@ -230,4 +237,5 @@ func New(c fs.Config) (model.Model, error) {
|
||||
func init() {
|
||||
model.Register("qwen3", New)
|
||||
model.Register("qwen3moe", New)
|
||||
model.Register("qwen3_embed", newEmbed)
|
||||
}
|
||||
|
||||
49
model/parsers/parsers.go
Normal file
49
model/parsers/parsers.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
)
|
||||
|
||||
type Parser interface {
|
||||
// Init initializes the parser with tools and optional last message for chat prefill
|
||||
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
|
||||
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
|
||||
// Add processes streamed content and returns parsed content, thinking, and tool calls
|
||||
// The done flag indicates if this is the last chunk (used for draining accumulators)
|
||||
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
|
||||
HasToolSupport() bool
|
||||
HasThinkingSupport() bool
|
||||
}
|
||||
|
||||
func ParserForName(name string) Parser {
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
parser := &Qwen3CoderParser{}
|
||||
return parser
|
||||
case "passthrough":
|
||||
return &PassthroughParser{}
|
||||
case "harmony":
|
||||
return harmony.NewHarmonyMessageHandler()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type PassthroughParser struct{}
|
||||
|
||||
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
return tools // passthrough doesn't modify tools
|
||||
}
|
||||
|
||||
func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
return s, "", nil, nil
|
||||
}
|
||||
|
||||
func (p *PassthroughParser) HasToolSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *PassthroughParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
463
model/parsers/qwen3coder.go
Normal file
463
model/parsers/qwen3coder.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type qwenParserState int
|
||||
|
||||
const (
|
||||
toolOpenTag = "<tool_call>"
|
||||
toolCloseTag = "</tool_call>"
|
||||
)
|
||||
|
||||
const (
|
||||
qwenParserState_LookingForToolStart qwenParserState = iota
|
||||
qwenParserState_CollectingToolContent
|
||||
)
|
||||
|
||||
type Qwen3CoderParser struct {
|
||||
state qwenParserState
|
||||
acc strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
p.tools = tools
|
||||
return tools // Qwen doesn't modify tools
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.acc.WriteString(s)
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var sb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case qwenEventRawToolCall:
|
||||
toolCall, err := parseToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
// events, we naively append them together here. See the note below about
|
||||
// `qwenEvent`s for more details
|
||||
sb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), "", toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
|
||||
var all []qwenEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []qwenEvent
|
||||
events, keepLooping = eat(p)
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
// we use some internal event types in order to communicate between `Add` and
|
||||
// `eat`. We do this to support interleaving content and parallel tool calls in
|
||||
// the parser, even though qwen3-coder isn't supposed to do this. Our API
|
||||
// doesn't currently support models outputting multiple messages in a turn, so
|
||||
// we wouldn't be able to represent it yet, but there's no reason to prevent the
|
||||
// parser from supporting it, especially for future models if they end up using
|
||||
// a similar format.
|
||||
type qwenEvent interface {
|
||||
isQwenEvent()
|
||||
}
|
||||
|
||||
type qwenEventRawToolCall struct {
|
||||
raw string
|
||||
}
|
||||
|
||||
type qwenEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (qwenEventContent) isQwenEvent() {}
|
||||
func (qwenEventRawToolCall) isQwenEvent() {}
|
||||
|
||||
// eat consumes the parser's buffer, and returns a list of any unambiguous
|
||||
// events from the current parser state. If the parser transitions to another
|
||||
// state, it may have additional events to emit on the next call, which is what
|
||||
// the second return value indicates
|
||||
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
|
||||
var events []qwenEvent
|
||||
|
||||
switch p.state {
|
||||
case qwenParserState_LookingForToolStart:
|
||||
if strings.Contains(p.acc.String(), toolOpenTag) {
|
||||
// we found a full tool open tag, so we can emit the content before the
|
||||
// tag, being sure to trim any trailing whitespace
|
||||
split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(after)
|
||||
p.state = qwenParserState_CollectingToolContent
|
||||
return events, true
|
||||
} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
|
||||
// we found a partial tool open tag, so we can emit the unambiguous part,
|
||||
// which is the (trailing-whitespace trimmed) content before the partial
|
||||
// tool open tag
|
||||
beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
unambiguous := p.acc.String()[:ambiguousStart]
|
||||
ambiguous := p.acc.String()[ambiguousStart:]
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(ambiguous)
|
||||
events = append(events, qwenEventContent{content: unambiguous})
|
||||
return events, false
|
||||
} else {
|
||||
// we found content that is entirely not a tool call. We should withhold
|
||||
// any trailing whitespace in case this is the end of the content
|
||||
whitespaceLen := trailingWhitespaceLen(p.acc.String())
|
||||
ambiguousStart := len(p.acc.String()) - whitespaceLen
|
||||
unambiguous := p.acc.String()[:ambiguousStart]
|
||||
ambiguous := p.acc.String()[ambiguousStart:]
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, qwenEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
case qwenParserState_CollectingToolContent:
|
||||
if strings.Contains(p.acc.String(), toolCloseTag) {
|
||||
split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
|
||||
before := split[0]
|
||||
if len(before) == 0 {
|
||||
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||
}
|
||||
// remove any whitespace between the tool call and any content after it
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
p.acc.Reset()
|
||||
p.acc.WriteString(after)
|
||||
events = append(events, qwenEventRawToolCall{raw: before})
|
||||
p.state = qwenParserState_LookingForToolStart
|
||||
return events, true
|
||||
} else {
|
||||
// note that we don't need to check the overlap here because we only plan
|
||||
// on parsing the tool call once we see the full closing tag. We don't
|
||||
// stream back the unparsed tool content, so there's no need to be eager
|
||||
// here
|
||||
return events, false
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(drifkin): move this to a shared location
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func trailingWhitespaceLen(s string) int {
|
||||
remaining := s
|
||||
total := 0
|
||||
for len(remaining) > 0 {
|
||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
break
|
||||
}
|
||||
if !unicode.IsSpace(r) {
|
||||
break
|
||||
}
|
||||
total += size
|
||||
remaining = remaining[:len(remaining)-size]
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
type XMLFunctionCall struct {
|
||||
XMLName xml.Name `xml:"function"`
|
||||
Name string `xml:"name,attr"`
|
||||
Parameters []XMLParameter `xml:"parameter"`
|
||||
}
|
||||
|
||||
type XMLParameter struct {
|
||||
Name string `xml:"name,attr"`
|
||||
Value string `xml:",chardata"`
|
||||
}
|
||||
|
||||
// parseToolCall parses a raw tool call string into an api.ToolCall.
|
||||
// The raw string follows an xml-like format, here's an example:
|
||||
//
|
||||
// <function=get_current_temperature>
|
||||
// <parameter=location>
|
||||
// San Francisco
|
||||
// </parameter>
|
||||
// <parameter=unit>
|
||||
// celsius
|
||||
// </parameter>
|
||||
// </function>
|
||||
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
toolCall := api.ToolCall{}
|
||||
|
||||
xmlString := transformToXML(raw.raw)
|
||||
|
||||
var functionCall XMLFunctionCall
|
||||
err := xml.Unmarshal([]byte(xmlString), &functionCall)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
|
||||
toolCall.Function = api.ToolCallFunction{
|
||||
Name: functionCall.Name,
|
||||
}
|
||||
|
||||
// Find the matching tool to get parameter types
|
||||
var matchedTool *api.Tool
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == functionCall.Name {
|
||||
matchedTool = &tools[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||
for _, parameter := range functionCall.Parameters {
|
||||
// Look up the parameter type if we found the tool
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
|
||||
paramType = prop.Type
|
||||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
|
||||
// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
|
||||
//
|
||||
// For union types (multiple types in PropertyType, which we support but doesn't
|
||||
// seem as though the reference parser does type coercion with those types in
|
||||
// mind) we use a type precedence approach:
|
||||
// 1. null - checked first regardless of declared types (matches reference implementation)
|
||||
// 2. boolean - only "true"/"false" are valid booleans
|
||||
// 3. integer - must parse as a whole number
|
||||
// 4. number - must parse as numeric (returns int if no decimal part)
|
||||
// 5. array - must parse as valid JSON array
|
||||
// 6. object - must parse as valid JSON object
|
||||
// 7. string - always succeeds (least specific type)
|
||||
//
|
||||
// This precedence ensures we return the most specific type that successfully parses,
|
||||
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
|
||||
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
|
||||
func parseValue(raw string, paramType api.PropertyType) any {
|
||||
// first remove a single leading newlines, and a single trailing newline (if
|
||||
// they exist). This follows the reference implementation
|
||||
raw = strings.TrimPrefix(raw, "\n")
|
||||
raw = strings.TrimSuffix(raw, "\n")
|
||||
|
||||
// Check for null first (case-insensitive) - this takes precedence over any type
|
||||
if strings.ToLower(raw) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If no type is specified, default to string
|
||||
if len(paramType) == 0 {
|
||||
return raw
|
||||
}
|
||||
|
||||
// Check if any of the specified types match, using type precedence
|
||||
// Order: boolean -> integer -> number -> array -> object -> string
|
||||
typeSet := make(map[string]bool)
|
||||
for _, t := range paramType {
|
||||
typeSet[t] = true
|
||||
}
|
||||
|
||||
// Try boolean first (most restrictive)
|
||||
if typeSet["boolean"] {
|
||||
lower := strings.ToLower(raw)
|
||||
switch lower {
|
||||
case "true":
|
||||
return true
|
||||
case "false":
|
||||
return false
|
||||
}
|
||||
// If not a valid boolean but boolean is the only type, return false (matching reference)
|
||||
if len(paramType) == 1 {
|
||||
return false
|
||||
}
|
||||
// Otherwise try other types
|
||||
}
|
||||
|
||||
// Try integer
|
||||
if typeSet["integer"] {
|
||||
if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
|
||||
// Return as int if it fits in int32, otherwise int64
|
||||
if i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||
return int(i)
|
||||
}
|
||||
return i
|
||||
}
|
||||
// If integer is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// Try number (float)
|
||||
if typeSet["number"] {
|
||||
if f, err := strconv.ParseFloat(raw, 64); err == nil {
|
||||
// If the number has no decimal part, return as int (matching reference)
|
||||
if f == math.Trunc(f) {
|
||||
i := int64(f)
|
||||
if i >= math.MinInt32 && i <= math.MaxInt32 {
|
||||
return int(i)
|
||||
}
|
||||
return i
|
||||
}
|
||||
return f
|
||||
}
|
||||
// If number is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// Try array
|
||||
if typeSet["array"] {
|
||||
var arr []any
|
||||
if err := json.Unmarshal([]byte(raw), &arr); err == nil {
|
||||
return arr
|
||||
}
|
||||
// If array is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// Try object
|
||||
if typeSet["object"] {
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
|
||||
return obj
|
||||
}
|
||||
// If object is the only type and parsing failed, fall back to string
|
||||
if len(paramType) == 1 {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
// String always succeeds (or if "string" is in the type set)
|
||||
if typeSet["string"] {
|
||||
return raw
|
||||
}
|
||||
|
||||
// If we get here, none of the types matched and string wasn't an option
|
||||
// We return string as a fallback. The reference implementation will attempt
|
||||
// to parse the value as a python literal, but we purposefully don't support
|
||||
// that
|
||||
return raw
|
||||
}
|
||||
|
||||
var (
|
||||
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
|
||||
)
|
||||
|
||||
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
||||
// xml so that it can be parsed by any xml parser
|
||||
func transformToXML(raw string) string {
|
||||
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
||||
// care to properly escape the string that becomes the attribute value
|
||||
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||
groups := qwenTagRegex.FindStringSubmatch(match)
|
||||
tag := groups[1]
|
||||
var escapedValue strings.Builder
|
||||
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
||||
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
||||
})
|
||||
|
||||
// Walk the resulting string, escaping any character data that sits between the
|
||||
// xml tags we just emitted
|
||||
var out strings.Builder
|
||||
lastIdx := 0
|
||||
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
|
||||
if loc[0] > lastIdx {
|
||||
escapeTextNode(&out, transformed[lastIdx:loc[0]])
|
||||
}
|
||||
out.WriteString(transformed[loc[0]:loc[1]])
|
||||
lastIdx = loc[1]
|
||||
}
|
||||
if lastIdx < len(transformed) {
|
||||
escapeTextNode(&out, transformed[lastIdx:])
|
||||
}
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
// escapeTextNode escapes XML character data without altering other characters
|
||||
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
|
||||
func escapeTextNode(sb *strings.Builder, s string) {
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '&':
|
||||
sb.WriteString("&")
|
||||
case '<':
|
||||
sb.WriteString("<")
|
||||
case '>':
|
||||
sb.WriteString(">")
|
||||
default:
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
1095
model/parsers/qwen3coder_test.go
Normal file
1095
model/parsers/qwen3coder_test.go
Normal file
File diff suppressed because it is too large
Load Diff
217
model/renderers/qwen3coder.go
Normal file
217
model/renderers/qwen3coder.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
var (
|
||||
imStartTag = "<|im_start|>"
|
||||
imEndTag = "<|im_end|>"
|
||||
)
|
||||
|
||||
// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
|
||||
// This follows the same approach from the reference implementation, which gives
|
||||
// a particular key ordering
|
||||
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
|
||||
data, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for key, value := range m {
|
||||
if handledKeys[key] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if value is a map or array (needs JSON serialization)
|
||||
switch v := value.(type) {
|
||||
case map[string]any, []any:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
// TODO(drifkin): it would be nice to format the JSON here similarly to
|
||||
// python's default json.dumps behavior (spaces after commas and colons).
|
||||
// This would let us be byte-for-byte compatible with the reference
|
||||
// implementation for most common inputs
|
||||
jsonStr := string(jsonBytes)
|
||||
sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
|
||||
case nil:
|
||||
continue
|
||||
default:
|
||||
// Simple types, convert to string
|
||||
sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
// filter out system messages and choose the first (if any) to win
|
||||
var systemMessage string
|
||||
var filteredMessages []api.Message
|
||||
for _, message := range messages {
|
||||
if message.Role != "system" {
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
continue
|
||||
}
|
||||
|
||||
if systemMessage == "" {
|
||||
systemMessage = message.Content
|
||||
}
|
||||
}
|
||||
|
||||
if systemMessage != "" || len(tools) > 0 {
|
||||
sb.WriteString(imStartTag + "system\n")
|
||||
|
||||
// if we have tools but no system message, match the reference implementation by providing a default system message
|
||||
if systemMessage == "" {
|
||||
systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
|
||||
}
|
||||
|
||||
sb.WriteString(systemMessage)
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
|
||||
sb.WriteString("<tools>")
|
||||
for _, tool := range tools {
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString("<function>\n")
|
||||
sb.WriteString("<name>" + tool.Function.Name + "</name>")
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
|
||||
}
|
||||
sb.WriteString("\n<parameters>")
|
||||
|
||||
for name, prop := range tool.Function.Parameters.Properties {
|
||||
sb.WriteString("\n<parameter>")
|
||||
sb.WriteString("\n<name>" + name + "</name>")
|
||||
|
||||
if len(prop.Type) > 0 {
|
||||
// TODO(!!!)(drifkin): we should match the reference implementation for
|
||||
// more complex types here instead of using this format
|
||||
sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
|
||||
}
|
||||
|
||||
if prop.Description != "" {
|
||||
sb.WriteString("\n<description>" + prop.Description + "</description>")
|
||||
}
|
||||
|
||||
// Render any additional keys not already handled
|
||||
handledKeys := map[string]bool{
|
||||
"type": true,
|
||||
"description": true,
|
||||
}
|
||||
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
|
||||
|
||||
sb.WriteString("\n</parameter>")
|
||||
}
|
||||
|
||||
// Render extra keys for parameters (everything except 'type' and 'properties')
|
||||
paramHandledKeys := map[string]bool{
|
||||
"type": true,
|
||||
"properties": true,
|
||||
}
|
||||
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
|
||||
|
||||
sb.WriteString("\n</parameters>")
|
||||
sb.WriteString("\n</function>")
|
||||
}
|
||||
sb.WriteString("\n</tools>")
|
||||
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
|
||||
}
|
||||
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
prefill := lastMessage && message.Role == "assistant"
|
||||
switch message.Role {
|
||||
case "assistant":
|
||||
if len(message.ToolCalls) > 0 {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content + "\n")
|
||||
}
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
||||
for name, value := range toolCall.Function.Arguments {
|
||||
valueStr := formatToolCallArgument(value)
|
||||
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
||||
}
|
||||
sb.WriteString("\n</function>\n</tool_call>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
} else {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
sb.WriteString(message.Content)
|
||||
if !prefill {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
// consecutive tool responses should share a single `<im_start>user`, but
|
||||
// have their own <tool_response> tags
|
||||
|
||||
// only start a new user block if this is the first tool response
|
||||
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
||||
sb.WriteString(imStartTag + "user\n")
|
||||
}
|
||||
|
||||
sb.WriteString("<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>\n")
|
||||
|
||||
// close the user block only if this is the last tool response
|
||||
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
default:
|
||||
sb.WriteString(imStartTag + message.Role + "\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString(imEndTag + "\n")
|
||||
}
|
||||
|
||||
if lastMessage && !prefill {
|
||||
sb.WriteString(imStartTag + "assistant\n")
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func formatToolCallArgument(value any) string {
|
||||
if value == nil {
|
||||
return "null"
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
}
|
||||
|
||||
if reflect.TypeOf(value) != nil {
|
||||
kind := reflect.TypeOf(value).Kind()
|
||||
if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
|
||||
if marshalled, err := json.Marshal(value); err == nil {
|
||||
return string(marshalled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value)
|
||||
}
|
||||
338
model/renderers/qwen3coder_test.go
Normal file
338
model/renderers/qwen3coder_test.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestQwen3CoderRenderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Hello, how are you?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "with tools and response",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "What is the weather like in San Francisco?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "I'll check the weather in San Francisco for you.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"unit": "fahrenheit",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
|
||||
{Role: "user", Content: "That sounds nice! What about New York?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Required: []string{"unit"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
||||
// TODO(drifkin): add multiple params back once we have predictable
|
||||
// order via some sort of ordered map type (see
|
||||
// <https://github.com/ollama/ollama/issues/12244>)
|
||||
/*
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
||||
*/
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant with access to tools.
|
||||
|
||||
# Tools
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
<tools>
|
||||
<function>
|
||||
<name>get_weather</name>
|
||||
<description>Get the current weather in a given location</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>unit</name>
|
||||
<type>string</type>
|
||||
<description>The unit of temperature</description>
|
||||
<enum>["celsius","fahrenheit"]</enum>
|
||||
</parameter>
|
||||
<required>["unit"]</required>
|
||||
</parameters>
|
||||
</function>
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT><|im_end|>
|
||||
<|im_start|>user
|
||||
What is the weather like in San Francisco?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'll check the weather in San Francisco for you.
|
||||
|
||||
<tool_call>
|
||||
<function=get_weather>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>user
|
||||
That sounds nice! What about New York?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "call double(1) and triple(2)"},
|
||||
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
|
||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
||||
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
||||
}}}},
|
||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
||||
}}}},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant with access to tools.
|
||||
|
||||
# Tools
|
||||
|
||||
You have access to the following functions:
|
||||
|
||||
<tools>
|
||||
<function>
|
||||
<name>double</name>
|
||||
<description>Double a number</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>number</name>
|
||||
<type>string</type>
|
||||
<description>The number to double</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</function>
|
||||
<function>
|
||||
<name>triple</name>
|
||||
<description>Triple a number</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>number</name>
|
||||
<type>string</type>
|
||||
<description>The number to triple</description>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</function>
|
||||
</tools>
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with NO suffix:
|
||||
|
||||
<tool_call>
|
||||
<function=example_function_name>
|
||||
<parameter=example_parameter_1>
|
||||
value_1
|
||||
</parameter>
|
||||
<parameter=example_parameter_2>
|
||||
This is the value for the second parameter
|
||||
that can span
|
||||
multiple lines
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
<IMPORTANT>
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
|
||||
- Required parameters MUST be specified
|
||||
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
</IMPORTANT><|im_end|>
|
||||
<|im_start|>user
|
||||
call double(1) and triple(2)<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'll call double(1) and triple(2) for you.
|
||||
|
||||
<tool_call>
|
||||
<function=double>
|
||||
<parameter=number>
|
||||
1
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=triple>
|
||||
<parameter=number>
|
||||
2
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"number": 2}
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"number": 6}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "prefill",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Tell me something interesting."},
|
||||
{Role: "assistant", Content: "I'll tell you something interesting about cats"},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Tell me something interesting.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
I'll tell you something interesting about cats`,
|
||||
},
|
||||
{
|
||||
name: "complex tool call arguments should remain json encoded",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "call tool"},
|
||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: map[string]any{
|
||||
"payload": map[string]any{"foo": "bar"},
|
||||
},
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
||||
},
|
||||
expected: `<|im_start|>user
|
||||
call tool<|im_end|>
|
||||
<|im_start|>assistant
|
||||
|
||||
<tool_call>
|
||||
<function=echo>
|
||||
<parameter=payload>
|
||||
{"foo":"bar"}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"payload": {"foo": "bar"}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolCallArgument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
arg: "foo",
|
||||
// notice no quotes around the string
|
||||
expected: "foo",
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
arg: map[string]any{"foo": "bar"},
|
||||
expected: "{\"foo\":\"bar\"}",
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
arg: 1,
|
||||
expected: "1",
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
arg: true,
|
||||
expected: "true",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := formatToolCallArgument(tt.arg)
|
||||
if got != tt.expected {
|
||||
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
26
model/renderers/renderer.go
Normal file
26
model/renderers/renderer.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
|
||||
|
||||
func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
renderer := rendererForName(name)
|
||||
if renderer == nil {
|
||||
return "", fmt.Errorf("unknown renderer %q", name)
|
||||
}
|
||||
return renderer(msgs, tools, think)
|
||||
}
|
||||
|
||||
func rendererForName(name string) rendererFunc {
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
return Qwen3CoderRenderer
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -12,18 +12,18 @@ import (
|
||||
|
||||
const spmWhitespaceSep = "▁"
|
||||
|
||||
type SentencePieceModel struct {
|
||||
type SentencePiece struct {
|
||||
maxTokenLen int
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||
var _ TextProcessor = (*SentencePiece)(nil)
|
||||
|
||||
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
||||
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||
return spm.vocab
|
||||
}
|
||||
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
return SentencePieceModel{
|
||||
return SentencePiece{
|
||||
maxTokenLen: maxTokenLen,
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
||||
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||
return spm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
id := spm.vocab.Encode(special)
|
||||
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
|
||||
return item
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
data := spm.vocab.Decode(id)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
)
|
||||
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||
}
|
||||
}
|
||||
|
||||
return NewSentencePieceModel(&v)
|
||||
return NewSentencePiece(&v)
|
||||
}
|
||||
|
||||
func TestSentencePieceEncode(t *testing.T) {
|
||||
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
||||
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||
vocab := &Vocabulary{
|
||||
Values: []string{
|
||||
"normal",
|
||||
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
||||
Scores: []float32{0, 0, 0, 0, 0},
|
||||
}
|
||||
|
||||
spm := NewSentencePieceModel(vocab)
|
||||
spm := NewSentencePiece(vocab)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
167
model/wordpiece.go
Normal file
167
model/wordpiece.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type WordPiece struct {
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
|
||||
// this differs from original word piece which uses "##" to indicate subwords.
|
||||
const ggmlPrefix = "▁"
|
||||
|
||||
var wordPieceReplacer = strings.NewReplacer(
|
||||
" .", ".",
|
||||
" ?", "?",
|
||||
" !", "!",
|
||||
" ,", ",",
|
||||
" ' ", "'",
|
||||
" n't", "n't",
|
||||
" 'm", "'m",
|
||||
" do not", " don't",
|
||||
" 's", "'s",
|
||||
" 've", "'ve",
|
||||
" 're", "'re",
|
||||
)
|
||||
|
||||
// Decode implements TextProcessor.
|
||||
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for i, id := range ids {
|
||||
if id < 0 || int(id) >= len(wpm.vocab.Values) {
|
||||
return "", fmt.Errorf("invalid token id: %d", id)
|
||||
}
|
||||
|
||||
var separator string
|
||||
piece := wpm.vocab.Values[id]
|
||||
if i > 0 &&
|
||||
(strings.HasPrefix(piece, ggmlPrefix) ||
|
||||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
|
||||
separator = " "
|
||||
}
|
||||
|
||||
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
// words splits a string into words, treating CJK characters as separate words.
|
||||
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
|
||||
func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||
return func(yield func(string) bool) {
|
||||
runes := make([]rune, 0, len(s)*3)
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 0x4E00 && r <= 0x9FFF,
|
||||
r >= 0x3400 && r <= 0x4DBF,
|
||||
r >= 0x20000 && r <= 0x2A6DF,
|
||||
r >= 0x2A700 && r <= 0x2B73F,
|
||||
r >= 0x2B740 && r <= 0x2B81F,
|
||||
r >= 0x2B820 && r <= 0x2CEAF,
|
||||
r >= 0xF900 && r <= 0xFAFF,
|
||||
r >= 0x2F800 && r <= 0x2FA1F:
|
||||
runes = append(runes, ' ', r, ' ')
|
||||
default:
|
||||
runes = append(runes, r)
|
||||
}
|
||||
}
|
||||
|
||||
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
|
||||
// split on but keep punctuation
|
||||
var start int
|
||||
for start < len(w) {
|
||||
end := strings.IndexFunc(w[start:], unicode.IsPunct)
|
||||
if end < 0 {
|
||||
end = len(w) - start
|
||||
} else if end == 0 {
|
||||
end = 1
|
||||
}
|
||||
|
||||
if !yield(w[start : start+end]) {
|
||||
return
|
||||
}
|
||||
|
||||
start += end
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Encode implements TextProcessor.
|
||||
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
var ids []int32
|
||||
|
||||
// TODO: use [UNK] from config
|
||||
unk := wpm.vocab.Encode("[UNK]")
|
||||
for word := range wpm.words(s) {
|
||||
var start int
|
||||
var pieces []int32
|
||||
for start < len(word) {
|
||||
end := len(word)
|
||||
|
||||
var piece int32
|
||||
for start < end {
|
||||
subword := word[start:end]
|
||||
if start == 0 {
|
||||
subword = ggmlPrefix + subword
|
||||
}
|
||||
|
||||
// TODO: some models might not want [ToLower]
|
||||
piece = wpm.vocab.Encode(strings.ToLower(subword))
|
||||
if piece >= 0 {
|
||||
break
|
||||
}
|
||||
|
||||
end--
|
||||
}
|
||||
|
||||
if piece < 0 {
|
||||
// Unknown token
|
||||
pieces = pieces[:0]
|
||||
break
|
||||
}
|
||||
|
||||
pieces = append(pieces, piece)
|
||||
start = end
|
||||
}
|
||||
|
||||
if len(pieces) > 0 {
|
||||
ids = append(ids, pieces...)
|
||||
} else {
|
||||
ids = append(ids, unk)
|
||||
}
|
||||
}
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = wpm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// Is implements TextProcessor.
|
||||
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||
return wpm.vocab.Is(id, special)
|
||||
}
|
||||
|
||||
// Vocabulary implements TextProcessor.
|
||||
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||
return wpm.vocab
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*WordPiece)(nil)
|
||||
|
||||
func NewWordPiece(vocab *Vocabulary) WordPiece {
|
||||
return WordPiece{
|
||||
vocab: vocab,
|
||||
}
|
||||
}
|
||||
51
model/wordpiece_test.go
Normal file
51
model/wordpiece_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWordPiece(t *testing.T) {
|
||||
wpm := NewWordPiece(
|
||||
&Vocabulary{
|
||||
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||
AddBOS: true,
|
||||
AddEOS: true,
|
||||
BOS: []int32{1},
|
||||
EOS: []int32{2},
|
||||
})
|
||||
|
||||
ids, err := wpm.Encode("Hello world!", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
words, err := wpm.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWordPieceWords(t *testing.T) {
|
||||
var wpm WordPiece
|
||||
|
||||
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
@@ -105,16 +105,18 @@ type ChatCompletionRequest struct {
|
||||
Tools []api.Tool `json:"tools"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
DebugRenderOnly bool `json:"_debug_render_only"`
|
||||
}
|
||||
|
||||
type ChatCompletion struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage,omitempty"`
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
SystemFingerprint string `json:"system_fingerprint"`
|
||||
Choices []Choice `json:"choices"`
|
||||
Usage Usage `json:"usage,omitempty"`
|
||||
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionChunk struct {
|
||||
@@ -141,6 +143,7 @@ type CompletionRequest struct {
|
||||
Temperature *float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
Suffix string `json:"suffix"`
|
||||
DebugRenderOnly bool `json:"_debug_render_only"`
|
||||
}
|
||||
|
||||
type Completion struct {
|
||||
@@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: toUsage(r),
|
||||
}}, Usage: toUsage(r),
|
||||
DebugInfo: r.DebugInfo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
return &api.ChatRequest{
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Tools: r.Tools,
|
||||
Think: think,
|
||||
Model: r.Model,
|
||||
Messages: messages,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Tools: r.Tools,
|
||||
Think: think,
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
}
|
||||
|
||||
return api.GenerateRequest{
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Suffix: r.Suffix,
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Suffix: r.Suffix,
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
||||
req.System = c.Args
|
||||
case "license":
|
||||
licenses = append(licenses, c.Args)
|
||||
case "renderer":
|
||||
req.Renderer = c.Args
|
||||
case "parser":
|
||||
req.Parser = c.Args
|
||||
case "message":
|
||||
role, msg, _ := strings.Cut(c.Args, ": ")
|
||||
messages = append(messages, api.Message{Role: role, Content: msg})
|
||||
@@ -320,7 +324,7 @@ func (c Command) String() string {
|
||||
switch c.Name {
|
||||
case "model":
|
||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||
case "license", "template", "system", "adapter":
|
||||
case "license", "template", "system", "adapter", "renderer", "parser":
|
||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
||||
case "message":
|
||||
role, message, _ := strings.Cut(c.Args, ": ")
|
||||
@@ -346,7 +350,7 @@ const (
|
||||
var (
|
||||
errMissingFrom = errors.New("no FROM line")
|
||||
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
|
||||
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
|
||||
)
|
||||
|
||||
type ParserError struct {
|
||||
@@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {
|
||||
|
||||
func isValidCommand(cmd string) bool {
|
||||
switch strings.ToLower(cmd) {
|
||||
case "from", "license", "template", "system", "adapter", "parameter", "message":
|
||||
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFileRenderer(t *testing.T) {
|
||||
input := `
|
||||
FROM foo
|
||||
RENDERER renderer1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
modelfile, err := ParseFile(reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
|
||||
}
|
||||
|
||||
func TestParseFileParser(t *testing.T) {
|
||||
input := `
|
||||
FROM foo
|
||||
PARSER parser1
|
||||
`
|
||||
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
modelfile, err := ParseFile(reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
|
||||
}
|
||||
|
||||
func TestParseFileMessages(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
)
|
||||
|
||||
type TokenParserType int
|
||||
|
||||
const (
|
||||
TokenParserTypeDefault TokenParserType = iota
|
||||
TokenParserTypeHarmony
|
||||
)
|
||||
|
||||
type TokenParser struct {
|
||||
messageHandler MessageHandler
|
||||
parserEngine ParserInternals
|
||||
toolParser ToolParser
|
||||
lastToken string
|
||||
tokenRepeat int
|
||||
repeatLimit int
|
||||
}
|
||||
|
||||
const defaultTokenRepeatLimit = 30
|
||||
|
||||
type MessageHandler interface {
|
||||
AddContent(token string) (content, thinking string, toolContent string)
|
||||
}
|
||||
|
||||
type ParserInternals interface {
|
||||
AddImplicitStartOrPrefill(prefillString string)
|
||||
}
|
||||
|
||||
type ToolParser interface {
|
||||
Add(token string)
|
||||
Drain() (toolName *string, toolContent string)
|
||||
}
|
||||
|
||||
// Default implementation for the TokenParser interface as a no-op passthrough
|
||||
type defaultMessageHandler struct{}
|
||||
|
||||
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
|
||||
return token, "", ""
|
||||
}
|
||||
|
||||
type defaultEngine struct{}
|
||||
|
||||
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
|
||||
|
||||
type defaultToolParser struct{}
|
||||
|
||||
func (defaultToolParser) Add(token string) {}
|
||||
|
||||
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
|
||||
|
||||
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
|
||||
switch parserType {
|
||||
case TokenParserTypeHarmony:
|
||||
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
|
||||
return TokenParser{
|
||||
messageHandler: harmonyMessageHandler,
|
||||
parserEngine: harmonyMessageHandler.HarmonyParser,
|
||||
toolParser: harmonyMessageHandler.ToolParser,
|
||||
repeatLimit: defaultTokenRepeatLimit,
|
||||
}
|
||||
|
||||
default:
|
||||
return TokenParser{
|
||||
messageHandler: defaultMessageHandler{},
|
||||
parserEngine: defaultEngine{},
|
||||
toolParser: defaultToolParser{},
|
||||
repeatLimit: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *TokenParser) AddContent(token string) (string, string, error) {
|
||||
if p.repeatLimitReached(token) {
|
||||
return "", "", errors.New("token repeat limit reached")
|
||||
}
|
||||
content, thinking, toolContent := p.messageHandler.AddContent(token)
|
||||
p.toolParser.Add(toolContent)
|
||||
return content, thinking, nil
|
||||
}
|
||||
|
||||
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
|
||||
func (p *TokenParser) repeatLimitReached(token string) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
trimmed := strings.TrimSpace(token)
|
||||
if trimmed == p.lastToken {
|
||||
p.tokenRepeat++
|
||||
} else {
|
||||
p.tokenRepeat = 0
|
||||
}
|
||||
p.lastToken = trimmed
|
||||
|
||||
return p.tokenRepeat >= p.repeatLimit
|
||||
}
|
||||
|
||||
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
|
||||
func (p *TokenParser) Drain() []api.ToolCall {
|
||||
toolName, toolContent := p.toolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
return nil
|
||||
}
|
||||
return []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: *toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -204,13 +204,8 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||
targetFree = max(targetFree, 1)
|
||||
|
||||
currentFree := c.numCtx - inputLen
|
||||
discard := targetFree - currentFree
|
||||
|
||||
if discard < 0 {
|
||||
discard = 0
|
||||
}
|
||||
|
||||
return discard
|
||||
return max(targetFree-currentFree, 0)
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
|
||||
@@ -242,13 +242,8 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
||||
targetFree = max(targetFree, 1)
|
||||
|
||||
currentFree := c.numCtx - inputLen
|
||||
discard := targetFree - currentFree
|
||||
|
||||
if discard < 0 {
|
||||
discard = 0
|
||||
}
|
||||
|
||||
return discard
|
||||
return max(targetFree-currentFree, 0)
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"image"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -32,9 +31,9 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
|
||||
@@ -406,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
|
||||
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
|
||||
|
||||
var activeBatch batchState
|
||||
for {
|
||||
@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
|
||||
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||
var batchInputs []*input.Input
|
||||
var batchOutputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
resumeSeq := -1
|
||||
@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
seq.iBatch = len(batchOutputs)
|
||||
if i+1 == len(seq.inputs) || seq.embeddingOnly {
|
||||
batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
|
||||
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
||||
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
||||
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
|
||||
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to build graph: %w", err)
|
||||
@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
var thinking string
|
||||
var err error
|
||||
content, thinking, err = tokenParser.AddContent(content)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
Thinking: thinking,
|
||||
Content: content,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher.Flush()
|
||||
} else {
|
||||
toolCalls := tokenParser.Drain()
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
ToolCalls: toolCalls,
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
@@ -913,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
|
||||
if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
|
||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Outputs = make([]int32, s.parallel)
|
||||
for i := range batch.Outputs {
|
||||
batch.Outputs[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
|
||||
|
||||
cache := s.model.Config().Cache
|
||||
if cache != nil {
|
||||
|
||||
@@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
|
||||
merges := make([]string, 0, 1)
|
||||
// Only need vocab for Grammar Test
|
||||
return model.NewBytePairEncoding(
|
||||
``,
|
||||
&model.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: make([]int32, len(vocab)),
|
||||
|
||||
@@ -16,6 +16,7 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
||||
--build-arg=OLLAMA_FAST_BUILD \
|
||||
--build-arg=CUSTOM_CPU_FLAGS \
|
||||
--build-arg=GPU_RUNNER_CPU_FLAGS \
|
||||
--build-arg=PARALLEL \
|
||||
--build-arg=AMDGPU_TARGETS"
|
||||
|
||||
echo "Building Ollama"
|
||||
|
||||
150
server/create.go
150
server/create.go
@@ -10,8 +10,11 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
@@ -39,6 +42,14 @@ var (
|
||||
)
|
||||
|
||||
func (s *Server) CreateHandler(c *gin.Context) {
|
||||
config := &ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
var r api.CreateRequest
|
||||
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
@@ -48,6 +59,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
config.Renderer = r.Renderer
|
||||
config.Parser = r.Parser
|
||||
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
@@ -77,20 +91,34 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
oldManifest, _ := ParseNamedManifest(name)
|
||||
|
||||
var baseLayers []*layerGGML
|
||||
var err error
|
||||
var remote bool
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name")
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromName := model.ParseName(r.From)
|
||||
if !fromName.IsValid() {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
if r.RemoteHost != "" {
|
||||
ru, err := remoteURL(r.RemoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
config.RemoteModel = r.From
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
baseLayers, err = parseFromModel(ctx, fromName, fn)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
baseLayers, err = parseFromModel(ctx, fromName, fn)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}
|
||||
} else if r.Files != nil {
|
||||
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
|
||||
@@ -110,7 +138,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
var adapterLayers []*layerGGML
|
||||
if r.Adapters != nil {
|
||||
if !remote && r.Adapters != nil {
|
||||
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
||||
if err != nil {
|
||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
|
||||
@@ -128,7 +156,56 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
baseLayers = append(baseLayers, adapterLayers...)
|
||||
}
|
||||
|
||||
if err := createModel(r, name, baseLayers, fn); err != nil {
|
||||
// Info is not currently exposed by Modelfiles, but allows overriding various
|
||||
// config values
|
||||
if r.Info != nil {
|
||||
caps, ok := r.Info["capabilities"]
|
||||
if ok {
|
||||
switch tcaps := caps.(type) {
|
||||
case []any:
|
||||
caps := make([]string, len(tcaps))
|
||||
for i, c := range tcaps {
|
||||
str, ok := c.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
caps[i] = str
|
||||
}
|
||||
config.Capabilities = append(config.Capabilities, caps...)
|
||||
}
|
||||
}
|
||||
|
||||
strFromInfo := func(k string) string {
|
||||
v, ok := r.Info[k]
|
||||
if ok {
|
||||
val := v.(string)
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
vFromInfo := func(k string) float64 {
|
||||
v, ok := r.Info[k]
|
||||
if ok {
|
||||
val := v.(float64)
|
||||
return val
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
config.ModelFamily = strFromInfo("model_family")
|
||||
if config.ModelFamily != "" {
|
||||
config.ModelFamilies = []string{config.ModelFamily}
|
||||
}
|
||||
|
||||
config.BaseName = strFromInfo("base_name")
|
||||
config.FileType = strFromInfo("quantization_level")
|
||||
config.ModelType = strFromInfo("parameter_size")
|
||||
config.ContextLen = int(vFromInfo("context_length"))
|
||||
config.EmbedLen = int(vFromInfo("embedding_length"))
|
||||
}
|
||||
|
||||
if err := createModel(r, name, baseLayers, config, fn); err != nil {
|
||||
if errors.Is(err, errBadTemplate) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
@@ -154,6 +231,51 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
func remoteURL(raw string) (string, error) {
|
||||
// Special‑case: user supplied only a path ("/foo/bar").
|
||||
if strings.HasPrefix(raw, "/") {
|
||||
return (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort("localhost", "11434"),
|
||||
Path: path.Clean(raw),
|
||||
}).String(), nil
|
||||
}
|
||||
|
||||
if !strings.Contains(raw, "://") {
|
||||
raw = "http://" + raw
|
||||
}
|
||||
|
||||
if raw == "ollama.com" || raw == "http://ollama.com" {
|
||||
raw = "https://ollama.com:443"
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse error: %w", err)
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
u.Host = "localhost"
|
||||
}
|
||||
|
||||
hostPart, portPart, err := net.SplitHostPort(u.Host)
|
||||
if err == nil {
|
||||
u.Host = net.JoinHostPort(hostPart, portPart)
|
||||
} else {
|
||||
u.Host = net.JoinHostPort(u.Host, "11434")
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
u.Path = path.Clean(u.Path)
|
||||
}
|
||||
|
||||
if u.Path == "/" {
|
||||
u.Path = ""
|
||||
}
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
|
||||
switch detectModelTypeFromFiles(files) {
|
||||
case "safetensors":
|
||||
@@ -316,15 +438,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
|
||||
return ggml.KV{}, fmt.Errorf("no base model was found")
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, fn func(resp api.ProgressResponse)) (err error) {
|
||||
config := ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
|
||||
var layers []Layer
|
||||
for _, layer := range baseLayers {
|
||||
if layer.GGML != nil {
|
||||
@@ -404,7 +518,7 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
|
||||
return err
|
||||
}
|
||||
|
||||
configLayer, err := createConfigLayer(layers, config)
|
||||
configLayer, err := createConfigLayer(layers, *config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -104,3 +104,154 @@ func TestConvertFromSafetensors(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "absolute path",
|
||||
input: "/foo/bar",
|
||||
expected: "http://localhost:11434/foo/bar",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path with cleanup",
|
||||
input: "/foo/../bar",
|
||||
expected: "http://localhost:11434/bar",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "root path",
|
||||
input: "/",
|
||||
expected: "http://localhost:11434/",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host without scheme",
|
||||
input: "example.com",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host with port",
|
||||
input: "example.com:8080",
|
||||
expected: "http://example.com:8080",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "full URL",
|
||||
input: "https://example.com:8080/path",
|
||||
expected: "https://example.com:8080/path",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "full URL with path cleanup",
|
||||
input: "https://example.com:8080/path/../other",
|
||||
expected: "https://example.com:8080/other",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "ollama.com special case",
|
||||
input: "ollama.com",
|
||||
expected: "https://ollama.com:443",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "http ollama.com special case",
|
||||
input: "http://ollama.com",
|
||||
expected: "https://ollama.com:443",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with only host",
|
||||
input: "http://example.com",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with root path cleaned",
|
||||
input: "http://example.com/",
|
||||
expected: "http://example.com:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid URL",
|
||||
input: "http://[::1]:namedport", // invalid port
|
||||
expected: "",
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "http://localhost:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "host with scheme but no port",
|
||||
input: "http://localhost",
|
||||
expected: "http://localhost:11434",
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "complex path cleanup",
|
||||
input: "/a/b/../../c/./d",
|
||||
expected: "http://localhost:11434/c/d",
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := remoteURL(tt.input)
|
||||
|
||||
if tt.hasError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteURL_Idempotent(t *testing.T) {
|
||||
// Test that applying remoteURL twice gives the same result as applying it once
|
||||
testInputs := []string{
|
||||
"/foo/bar",
|
||||
"example.com",
|
||||
"https://example.com:8080/path",
|
||||
"ollama.com",
|
||||
"http://localhost:11434",
|
||||
}
|
||||
|
||||
for _, input := range testInputs {
|
||||
t.Run(input, func(t *testing.T) {
|
||||
firstResult, err := remoteURL(input)
|
||||
if err != nil {
|
||||
t.Fatalf("first call failed: %v", err)
|
||||
}
|
||||
|
||||
secondResult, err := remoteURL(firstResult)
|
||||
if err != nil {
|
||||
t.Fatalf("second call failed: %v", err)
|
||||
}
|
||||
|
||||
if firstResult != secondResult {
|
||||
t.Errorf("function is not idempotent: first=%q, second=%q", firstResult, secondResult)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
@@ -73,29 +74,38 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
if m.ModelPath != "" {
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
} else if len(m.Config.Capabilities) > 0 {
|
||||
for _, c := range m.Config.Capabilities {
|
||||
capabilities = append(capabilities, model.Capability(c))
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
slog.Warn("unknown capabilities for model", "model", m.Name)
|
||||
}
|
||||
|
||||
if m.Template == nil {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
builtinParser := parsers.ParserForName(m.Config.Parser)
|
||||
// Check for tools capability
|
||||
if slices.Contains(m.Template.Vars(), "tools") {
|
||||
if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
|
||||
capabilities = append(capabilities, model.CapabilityTools)
|
||||
}
|
||||
|
||||
@@ -109,10 +119,16 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Skip the thinking check if it's already set
|
||||
if slices.Contains(capabilities, "thinking") {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
hasTags := openingTag != "" && closingTag != ""
|
||||
if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
|
||||
isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
|
||||
if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
@@ -198,6 +214,20 @@ func (m *Model) String() string {
|
||||
})
|
||||
}
|
||||
|
||||
if m.Config.Renderer != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||
Name: "renderer",
|
||||
Args: m.Config.Renderer,
|
||||
})
|
||||
}
|
||||
|
||||
if m.Config.Parser != "" {
|
||||
modelfile.Commands = append(modelfile.Commands, parser.Command{
|
||||
Name: "parser",
|
||||
Args: m.Config.Parser,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range m.Options {
|
||||
switch v := v.(type) {
|
||||
case []any:
|
||||
@@ -236,8 +266,19 @@ type ConfigV2 struct {
|
||||
ModelFormat string `json:"model_format"`
|
||||
ModelFamily string `json:"model_family"`
|
||||
ModelFamilies []string `json:"model_families"`
|
||||
ModelType string `json:"model_type"`
|
||||
FileType string `json:"file_type"`
|
||||
ModelType string `json:"model_type"` // shown as Parameter Size
|
||||
FileType string `json:"file_type"` // shown as Quantization Level
|
||||
Renderer string `json:"renderer,omitempty"`
|
||||
Parser string `json:"parser,omitempty"`
|
||||
|
||||
RemoteHost string `json:"remote_host,omitempty"`
|
||||
RemoteModel string `json:"remote_model,omitempty"`
|
||||
|
||||
// used for remotes
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
ContextLen int `json:"context_length,omitempty"`
|
||||
EmbedLen int `json:"embedding_length,omitempty"`
|
||||
BaseName string `json:"base_name,omitempty"`
|
||||
|
||||
// required by spec
|
||||
Architecture string `json:"architecture"`
|
||||
|
||||
@@ -25,10 +25,7 @@ func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
|
||||
|
||||
// n^2 backoff timer is a little smoother than the
|
||||
// common choice of 2^n.
|
||||
d := time.Duration(n*n) * 10 * time.Millisecond
|
||||
if d > maxBackoff {
|
||||
d = maxBackoff
|
||||
}
|
||||
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
|
||||
// Randomize the delay between 0.5-1.5 x msec, in order
|
||||
// to prevent accidental "thundering herd" problems.
|
||||
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
@@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
}
|
||||
|
||||
thinkVal := false
|
||||
thinkLevel := ""
|
||||
if think != nil {
|
||||
thinkVal = think.Bool()
|
||||
thinkLevel = think.String()
|
||||
}
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
|
||||
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
s, err := tokenize(ctx, b.String())
|
||||
s, err := tokenize(ctx, p)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return p, images, nil
|
||||
}
|
||||
|
||||
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
||||
if m.Config.Renderer != "" {
|
||||
rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return rendered, nil
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
thinkVal := false
|
||||
thinkLevel := ""
|
||||
@@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
thinkVal = think.Bool()
|
||||
thinkLevel = think.String()
|
||||
}
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return b.String(), images, nil
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
628
server/routes.go
628
server/routes.go
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"slices"
|
||||
@@ -28,15 +30,15 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -47,6 +49,20 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func experimentEnabled(name string) bool {
|
||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||
}
|
||||
@@ -137,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
h, _ := os.Hostname()
|
||||
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
|
||||
}
|
||||
|
||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
var req api.GenerateRequest
|
||||
@@ -177,6 +204,90 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
origModel := req.Model
|
||||
|
||||
remoteURL, err := url.Parse(m.Config.RemoteHost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
|
||||
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = m.Config.RemoteModel
|
||||
|
||||
if req.Template == "" && m.Template.String() != "" {
|
||||
req.Template = m.Template.String()
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
|
||||
for k, v := range m.Options {
|
||||
if _, ok := req.Options[k]; !ok {
|
||||
req.Options[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// update the system prompt from the model if one isn't already specified
|
||||
if req.System == "" && m.System != "" {
|
||||
req.System = m.System
|
||||
}
|
||||
|
||||
if len(m.Messages) > 0 {
|
||||
slog.Warn("embedded messages in the model not supported with '/api/generate'; try '/api/chat' instead")
|
||||
}
|
||||
|
||||
fn := func(resp api.GenerateResponse) error {
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||
err = client.Generate(c, &req, fn)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
var apiError api.StatusError
|
||||
if errors.As(err, &apiError) {
|
||||
c.JSON(apiError.StatusCode, apiError)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
@@ -196,21 +307,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
|
||||
var parserType parser.TokenParserType
|
||||
if useHarmony {
|
||||
parserType = parser.TokenParserTypeHarmony
|
||||
} else {
|
||||
parserType = parser.TokenParserTypeDefault
|
||||
}
|
||||
var functionNameMap *harmony.FunctionNameMap
|
||||
|
||||
if useHarmony {
|
||||
functionNameMap = harmony.NewFunctionNameMap()
|
||||
var builtinParser parsers.Parser
|
||||
if shouldUseHarmony(m) && m.Config.Parser == "" {
|
||||
m.Config.Parser = "harmony"
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
||||
if !req.Raw && m.Config.Parser != "" {
|
||||
builtinParser = parsers.ParserForName(m.Config.Parser)
|
||||
if builtinParser != nil {
|
||||
// no tools or last message for generate endpoint
|
||||
builtinParser.Init(nil, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||
return
|
||||
}
|
||||
@@ -322,10 +433,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
if req.DebugRenderOnly {
|
||||
c.JSON(http.StatusOK, api.DebugTemplateResponse{
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
DebugInfo: api.DebugInfo{
|
||||
DebugInfo: &api.DebugInfo{
|
||||
RenderedTemplate: prompt,
|
||||
ImageCount: len(images),
|
||||
},
|
||||
@@ -334,13 +445,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
var thinkingState *thinking.Parser
|
||||
if !useHarmony {
|
||||
if builtinParser == nil {
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||
thinkingState = &thinking.Parser{
|
||||
OpeningTag: openingTag,
|
||||
ClosingTag: closingTag,
|
||||
}
|
||||
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
|
||||
thinkingState.AddContent(openingTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,19 +464,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
var sb strings.Builder
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
ParserType: parserType,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: cr.Content,
|
||||
Done: cr.Done,
|
||||
Thinking: cr.Thinking,
|
||||
ToolCalls: cr.ToolCalls,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: cr.PromptEvalCount,
|
||||
PromptEvalDuration: cr.PromptEvalDuration,
|
||||
@@ -371,22 +482,18 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if res.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
for i, tool := range res.ToolCalls {
|
||||
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
||||
if builtinParser != nil {
|
||||
content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
|
||||
ch <- res
|
||||
res.Response = content
|
||||
res.Thinking = thinking
|
||||
if cr.Done && len(toolCalls) > 0 {
|
||||
res.ToolCalls = toolCalls
|
||||
}
|
||||
return
|
||||
}
|
||||
if thinkingState != nil {
|
||||
} else if thinkingState != nil {
|
||||
thinking, content := thinkingState.AddContent(cr.Content)
|
||||
res.Thinking = thinking
|
||||
res.Response = content
|
||||
@@ -397,6 +504,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
if !req.Raw {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||
if err != nil {
|
||||
@@ -407,7 +518,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
if builtinParser != nil {
|
||||
// only send messages with meaningful content (empty messages confuse clients)
|
||||
if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 {
|
||||
ch <- res
|
||||
@@ -470,7 +581,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
truncate := true
|
||||
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
@@ -533,11 +643,27 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||
ctxLen--
|
||||
}
|
||||
|
||||
slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
|
||||
if ctxLen <= 0 {
|
||||
// return error if the truncated input would be empty or just special tokens
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
|
||||
return
|
||||
}
|
||||
|
||||
tokens = tokens[:ctxLen]
|
||||
|
||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -904,6 +1030,28 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" {
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
|
||||
if m.Config.ModelFamily != "" {
|
||||
resp.ModelInfo = make(map[string]any)
|
||||
resp.ModelInfo["general.architecture"] = m.Config.ModelFamily
|
||||
|
||||
if m.Config.BaseName != "" {
|
||||
resp.ModelInfo["general.basename"] = m.Config.BaseName
|
||||
}
|
||||
|
||||
if m.Config.ContextLen > 0 {
|
||||
resp.ModelInfo[fmt.Sprintf("%s.context_length", m.Config.ModelFamily)] = m.Config.ContextLen
|
||||
}
|
||||
|
||||
if m.Config.EmbedLen > 0 {
|
||||
resp.ModelInfo[fmt.Sprintf("%s.embedding_length", m.Config.ModelFamily)] = m.Config.EmbedLen
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var params []string
|
||||
cs := 30
|
||||
for k, v := range m.Options {
|
||||
@@ -934,6 +1082,11 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
fmt.Fprint(&sb, m.String())
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
// skip loading tensor information if this is a remote model
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1010,11 +1163,13 @@ func (s *Server) ListHandler(c *gin.Context) {
|
||||
|
||||
// tag should never be masked
|
||||
models = append(models, api.ListModelResponse{
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
RemoteModel: cf.RemoteModel,
|
||||
RemoteHost: cf.RemoteHost,
|
||||
Size: m.Size(),
|
||||
Digest: m.digest,
|
||||
ModifiedAt: m.fi.ModTime(),
|
||||
Details: api.ModelDetails{
|
||||
Format: cf.ModelFormat,
|
||||
Family: cf.ModelFamily,
|
||||
@@ -1072,6 +1227,108 @@ func (s *Server) CopyHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) WebSearchHandler(c *gin.Context) {
|
||||
var req api.SearchRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Query == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "query is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxResults <= 0 {
|
||||
req.MaxResults = 5
|
||||
}
|
||||
|
||||
results, err := s.callWebSearchAPI(req.Query, req.MaxResults)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if len(results) > req.MaxResults {
|
||||
results = results[:req.MaxResults]
|
||||
}
|
||||
|
||||
resp := api.SearchResponse{
|
||||
Results: results,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) callWebSearchAPI(query string, maxResults int) ([]api.SearchResult, error) {
|
||||
searchReq := api.SearchRequest{
|
||||
Query: query,
|
||||
MaxResults: maxResults,
|
||||
}
|
||||
|
||||
client := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient)
|
||||
|
||||
searchResp, err := client.WebSearch(context.Background(), &searchReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return searchResp.Results, nil
|
||||
}
|
||||
|
||||
func (s *Server) FetchHandler(c *gin.Context) {
|
||||
var req api.FetchRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
if req.URL == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Call the real web fetch API
|
||||
content, title, err := s.callWebFetchAPI(req.URL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.FetchResponse{
|
||||
Content: content,
|
||||
Title: title,
|
||||
URL: req.URL,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) callWebFetchAPI(targetURL string) (string, string, error) {
|
||||
// Create request to ollama.com web fetch API
|
||||
fetchReq := api.FetchRequest{
|
||||
URL: targetURL,
|
||||
}
|
||||
|
||||
// Create client to call ollama.com
|
||||
client := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient)
|
||||
|
||||
// Call the web fetch API
|
||||
fetchResp, err := client.Fetch(context.Background(), &fetchReq)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return fetchResp.Content, fetchResp.Title, nil
|
||||
}
|
||||
|
||||
func (s *Server) HeadBlobHandler(c *gin.Context) {
|
||||
path, err := GetBlobsPath(c.Param("digest"))
|
||||
if err != nil {
|
||||
@@ -1274,6 +1531,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/show", s.ShowHandler)
|
||||
r.DELETE("/api/delete", s.DeleteHandler)
|
||||
|
||||
r.POST("/api/me", s.WhoamiHandler)
|
||||
|
||||
r.POST("/api/signout", s.SignoutHandler)
|
||||
// deprecated
|
||||
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
|
||||
|
||||
// Create
|
||||
r.POST("/api/create", s.CreateHandler)
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
@@ -1286,6 +1549,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.POST("/api/web_search", s.WebSearchHandler)
|
||||
r.POST("/api/web_fetch", s.FetchHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||
@@ -1470,6 +1735,70 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
client := api.NewClient(u, http.DefaultClient)
|
||||
user, err := client.Whoami(c)
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
}
|
||||
|
||||
// user isn't signed in
|
||||
if user != nil && user.Name == "" {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
slog.Error("couldn't get public key", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||
return
|
||||
}
|
||||
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "URL parse error"})
|
||||
return
|
||||
}
|
||||
|
||||
client := api.NewClient(u, http.DefaultClient)
|
||||
err = client.Disconnect(c, encKey)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, nil)
|
||||
}
|
||||
|
||||
func (s *Server) PsHandler(c *gin.Context) {
|
||||
models := []api.ProcessModelResponse{}
|
||||
|
||||
@@ -1526,21 +1855,34 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
case err.Error() == errtypes.InvalidModelNameErrMsg:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
case err.Error() == errtypes.InvalidModelNameErrMsg:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
s.sched.expireRunner(model)
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
Model: req.Model,
|
||||
@@ -1552,6 +1894,83 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
origModel := req.Model
|
||||
|
||||
remoteURL, err := url.Parse(m.Config.RemoteHost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(envconfig.Remotes(), remoteURL.Hostname()) {
|
||||
slog.Info("remote model", "remotes", envconfig.Remotes(), "remoteURL", m.Config.RemoteHost, "hostname", remoteURL.Hostname())
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "this server cannot run this remote model"})
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = m.Config.RemoteModel
|
||||
if req.Options == nil {
|
||||
req.Options = map[string]any{}
|
||||
}
|
||||
|
||||
msgs := append(m.Messages, req.Messages...)
|
||||
if req.Messages[0].Role != "system" && m.System != "" {
|
||||
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
req.Messages = msgs
|
||||
|
||||
for k, v := range m.Options {
|
||||
if _, ok := req.Options[k]; !ok {
|
||||
req.Options[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
fn := func(resp api.ChatResponse) error {
|
||||
resp.Model = origModel
|
||||
resp.RemoteModel = m.Config.RemoteModel
|
||||
resp.RemoteHost = m.Config.RemoteHost
|
||||
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = c.Writer.Write(append(data, '\n')); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||
err = client.Chat(c, &req, fn)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
var apiError api.StatusError
|
||||
if errors.As(err, &apiError) {
|
||||
c.JSON(apiError.StatusCode, apiError)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
caps := []model.Capability{model.CapabilityCompletion}
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
@@ -1560,17 +1979,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
@@ -1599,27 +2007,23 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
|
||||
var parserType parser.TokenParserType
|
||||
if useHarmony {
|
||||
parserType = parser.TokenParserTypeHarmony
|
||||
} else {
|
||||
parserType = parser.TokenParserTypeDefault
|
||||
if shouldUseHarmony(m) && m.Config.Parser == "" {
|
||||
m.Config.Parser = "harmony"
|
||||
}
|
||||
|
||||
var builtinParser parsers.Parser
|
||||
processedTools := req.Tools
|
||||
var functionNameMap *harmony.FunctionNameMap
|
||||
var prefillString string
|
||||
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
|
||||
if useHarmony {
|
||||
prefillString = harmony.Prefill(msgs[len(msgs)-1])
|
||||
functionNameMap = harmony.NewFunctionNameMap()
|
||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||
// renamed to be valid Harmony function names.
|
||||
processedTools = make([]api.Tool, len(req.Tools))
|
||||
copy(processedTools, req.Tools)
|
||||
for i, tool := range processedTools {
|
||||
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
|
||||
if m.Config.Parser != "" {
|
||||
builtinParser = parsers.ParserForName(m.Config.Parser)
|
||||
if builtinParser != nil {
|
||||
// Determine last message for chat prefill
|
||||
var lastMessage *api.Message
|
||||
if len(msgs) > 0 {
|
||||
lastMessage = &msgs[len(msgs)-1]
|
||||
}
|
||||
// Initialize parser and get processed tools
|
||||
processedTools = builtinParser.Init(req.Tools, lastMessage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1632,10 +2036,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
if req.DebugRenderOnly {
|
||||
c.JSON(http.StatusOK, api.DebugTemplateResponse{
|
||||
c.JSON(http.StatusOK, api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
DebugInfo: api.DebugInfo{
|
||||
DebugInfo: &api.DebugInfo{
|
||||
RenderedTemplate: prompt,
|
||||
ImageCount: len(images),
|
||||
},
|
||||
@@ -1643,8 +2047,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
||||
// Validate Think value: string values currently only allowed for harmony/gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())})
|
||||
return
|
||||
}
|
||||
@@ -1663,7 +2067,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
var toolParser *tools.Parser
|
||||
if len(req.Tools) > 0 && !useHarmony {
|
||||
if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) {
|
||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
@@ -1672,17 +2076,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
ParserType: parserType,
|
||||
PrefillString: prefillString,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
@@ -1697,14 +2099,26 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
for i, tool := range res.Message.ToolCalls {
|
||||
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
// only send messages with meaningful content (empty messages confuse clients)
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
||||
|
||||
res.Message.Content = content
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var stream bool = false
|
||||
@@ -615,6 +617,78 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateAndShowRemoteModel(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test",
|
||||
From: "bob",
|
||||
RemoteHost: "https://ollama.com",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion", "tools", "thinking"},
|
||||
"model_family": "gptoss",
|
||||
"context_length": 131072,
|
||||
"embedding_length": 2880,
|
||||
"quantization_level": "MXFP4",
|
||||
"parameter_size": "20.9B",
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("exected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("exected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expectedDetails := api.ModelDetails{
|
||||
ParentModel: "",
|
||||
Format: "",
|
||||
Family: "gptoss",
|
||||
Families: []string{"gptoss"},
|
||||
ParameterSize: "20.9B",
|
||||
QuantizationLevel: "MXFP4",
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(resp.Details, expectedDetails) {
|
||||
t.Errorf("model details: expected %#v, actual %#v", expectedDetails, resp.Details)
|
||||
}
|
||||
|
||||
expectedCaps := []model.Capability{
|
||||
model.Capability("completion"),
|
||||
model.Capability("tools"),
|
||||
model.Capability("thinking"),
|
||||
}
|
||||
|
||||
if !slices.Equal(resp.Capabilities, expectedCaps) {
|
||||
t.Errorf("capabilities: expected %#v, actual %#v", expectedCaps, resp.Capabilities)
|
||||
}
|
||||
|
||||
v, ok := resp.ModelInfo["gptoss.context_length"]
|
||||
ctxlen := v.(float64)
|
||||
if !ok || int(ctxlen) != 131072 {
|
||||
t.Errorf("context len: expected %d, actual %d", 131072, int(ctxlen))
|
||||
}
|
||||
|
||||
v, ok = resp.ModelInfo["gptoss.embedding_length"]
|
||||
embedlen := v.(float64)
|
||||
if !ok || int(embedlen) != 2880 {
|
||||
t.Errorf("embed len: expected %d, actual %d", 2880, int(embedlen))
|
||||
}
|
||||
|
||||
fmt.Printf("resp = %#v\n", resp)
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response api.DebugTemplateResponse
|
||||
var response api.GenerateResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response api.DebugTemplateResponse
|
||||
var response api.ChatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "content streams as it arrives",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "Hello", Done: false},
|
||||
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
|
||||
wantContent: "Hello",
|
||||
},
|
||||
{
|
||||
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
wantContent: ", world",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "!",
|
||||
},
|
||||
},
|
||||
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "thinking streams separately from content",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
|
||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
|
||||
wantThinking: "Thinking...",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "Answer", Done: false},
|
||||
wantContent: "Answer",
|
||||
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
|
||||
// No output expected - just closes the analysis message and resets state to normal
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
|
||||
wantContent: "Answer", // After message end, state is reset to normal
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
// No output expected - just closes the assistant message
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "partial tags buffer until complete",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
|
||||
input: llm.CompletionResponse{Content: "<|chan", Done: false},
|
||||
// No output - partial tag
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
|
||||
// No output - still building tags
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
|
||||
wantThinking: "Deep ",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Thinking: "thought", Done: false},
|
||||
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
|
||||
wantThinking: "thought",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Done",
|
||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Done", // After message end, state is reset to normal
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "simple assistant after analysis",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Answer",
|
||||
wantThinking: "Think",
|
||||
},
|
||||
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "tool call parsed and returned correctly",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "The weather is sunny",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "tool call with streaming JSON across chunks",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Done: false},
|
||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
|
||||
// No output yet - incomplete JSON
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
|
||||
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
|
||||
// Still no output - incomplete JSON
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "2\"}", Done: true},
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockResponses := []llm.CompletionResponse{
|
||||
{Content: "First ", Done: false},
|
||||
{Content: "<|message|>First ", Done: false},
|
||||
{Content: "chunk ", Done: false},
|
||||
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
}
|
||||
|
||||
mock := mockRunner{
|
||||
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
type expectedChunk struct {
|
||||
afterResponse int // Which mock response this chunk should appear after
|
||||
content string // Expected content in this chunk
|
||||
thinking string // Expected thinking in this chunk
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
mockResponses []llm.CompletionResponse
|
||||
expectedChunks []expectedChunk
|
||||
wantContent string
|
||||
wantThinking string
|
||||
}{
|
||||
{
|
||||
name: "simple message without thinking",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
|
||||
{Content: "how can I help?", Done: false},
|
||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 1, content: "Hello, "},
|
||||
{afterResponse: 2, content: "how can I help?"},
|
||||
},
|
||||
wantContent: "Hello, how can I help?",
|
||||
},
|
||||
{
|
||||
name: "message with analysis channel for thinking",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|channel|>analysis<|message|>", Done: false},
|
||||
{Content: "Let me think ", Done: false},
|
||||
{Content: "about this problem...", Done: false},
|
||||
{Content: "<|end|>", Done: false},
|
||||
{Content: "<|start|>assistant<|message|>", Done: false},
|
||||
{Content: "The answer ", Done: false},
|
||||
{Content: "is 42", Done: false},
|
||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 2, thinking: "Let me think "},
|
||||
{afterResponse: 3, thinking: "about this problem..."},
|
||||
{afterResponse: 6, content: "The answer "},
|
||||
{afterResponse: 7, content: "is 42"},
|
||||
},
|
||||
wantContent: "The answer is 42",
|
||||
wantThinking: "Let me think about this problem...",
|
||||
},
|
||||
{
|
||||
name: "streaming with partial tags across boundaries",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|chan", Done: false},
|
||||
{Content: "nel|>analy", Done: false},
|
||||
{Content: "sis<|mess", Done: false},
|
||||
{Content: "age|>Think", Done: false},
|
||||
{Content: "ing deeply...<|end|>", Done: false},
|
||||
{Content: "<|start|>assi", Done: false},
|
||||
{Content: "stant<|message|>Result ", Done: false},
|
||||
{Content: "computed<|e", Done: false},
|
||||
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 4, thinking: "Think"},
|
||||
{afterResponse: 5, thinking: "ing deeply..."},
|
||||
{afterResponse: 7, content: "Result "},
|
||||
{afterResponse: 8, content: "computed"},
|
||||
},
|
||||
wantContent: "Result computed",
|
||||
wantThinking: "Thinking deeply...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Channel to synchronize mock responses with chunk verification
|
||||
responsesSent := make(chan int, len(tc.mockResponses))
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Send mock responses one at a time, notifying when each is sent
|
||||
for i, resp := range tc.mockResponses {
|
||||
fn(resp)
|
||||
responsesSent <- i + 1
|
||||
}
|
||||
close(responsesSent)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a minimal model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
|
||||
// Create model with passthrough template
|
||||
stream := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "harmony-test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create model: %d", w.Code)
|
||||
}
|
||||
|
||||
// Test chat endpoint with streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "harmony-test",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Parse streaming response
|
||||
var chunks []api.ChatResponse
|
||||
var content, thinking strings.Builder
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
for decoder.More() {
|
||||
var chunk api.ChatResponse
|
||||
if err := decoder.Decode(&chunk); err != nil {
|
||||
t.Fatalf("failed to decode chunk: %v", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
|
||||
// Accumulate content and thinking from each chunk
|
||||
content.WriteString(chunk.Message.Content)
|
||||
thinking.WriteString(chunk.Message.Thinking)
|
||||
|
||||
// Debug output
|
||||
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
|
||||
}
|
||||
|
||||
// Verify we got streaming chunks
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("expected streaming chunks, got none")
|
||||
}
|
||||
|
||||
gotContent := content.String()
|
||||
gotThinking := thinking.String()
|
||||
|
||||
if gotContent != tc.wantContent {
|
||||
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
|
||||
}
|
||||
if gotThinking != tc.wantThinking {
|
||||
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
|
||||
}
|
||||
|
||||
// Verify last chunk has done=true
|
||||
lastChunk := chunks[len(chunks)-1]
|
||||
if !lastChunk.Done {
|
||||
t.Error("expected last chunk to have done=true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,7 +126,15 @@ func TestRoutes(t *testing.T) {
|
||||
t.Fatalf("failed to create model: %v", err)
|
||||
}
|
||||
|
||||
if err := createModel(r, modelName, baseLayers, fn); err != nil {
|
||||
config := &ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
RootFS: RootFS{
|
||||
Type: "layers",
|
||||
},
|
||||
}
|
||||
|
||||
if err := createModel(r, modelName, baseLayers, config, fn); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -382,10 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
||||
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
|
||||
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
|
||||
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
|
||||
numParallel := int(envconfig.NumParallel())
|
||||
if numParallel < 1 {
|
||||
numParallel = 1
|
||||
}
|
||||
numParallel := max(int(envconfig.NumParallel()), 1)
|
||||
|
||||
// Embedding models should always be loaded with parallel=1
|
||||
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
|
||||
|
||||
@@ -273,9 +273,21 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
||||
if args, ok := obj["arguments"].(map[string]any); ok {
|
||||
return args, true
|
||||
}
|
||||
if argsStr, ok := obj["arguments"].(string); ok {
|
||||
var argsData map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
||||
return argsData, ok
|
||||
}
|
||||
}
|
||||
if args, ok := obj["parameters"].(map[string]any); ok {
|
||||
return args, true
|
||||
}
|
||||
if argsStr, ok := obj["parameters"].(string); ok {
|
||||
var argsData map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
|
||||
return argsData, ok
|
||||
}
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
|
||||
|
||||
@@ -1274,6 +1274,22 @@ func TestFindArguments(t *testing.T) {
|
||||
"items": []any{"{", "}", map[string]any{"key": "value"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "stringified arguments",
|
||||
buffer: []byte(`{"name": "get_temperature", "arguments": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
|
||||
want: map[string]any{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "stringified parameters",
|
||||
buffer: []byte(`{"name": "get_temperature", "parameters": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
|
||||
want: map[string]any{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
Reference in New Issue
Block a user