Compare commits

..

8 Commits

Author SHA1 Message Date
Michael Yang
b59053a883 presets 2025-11-19 17:26:18 -08:00
Michael Yang
bb93e5afe7 errorlint 2025-11-19 17:26:18 -08:00
Michael Yang
4d24d8a77d gocritic 2025-11-19 17:26:18 -08:00
Michael Yang
f01c83ed6d fmt 2025-11-19 17:26:18 -08:00
Michael Yang
d3228355be staticcheck 2025-11-19 17:26:18 -08:00
Michael Yang
78a75a30d8 prealloc 2025-11-19 17:26:18 -08:00
Michael Yang
974ae8ef84 perfsprint 2025-11-19 17:26:17 -08:00
Michael Yang
efd9f5e67e modernize 2025-11-19 17:26:17 -08:00
123 changed files with 578 additions and 1070 deletions

2
.gitattributes vendored
View File

@@ -19,8 +19,6 @@ ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -11,6 +11,7 @@ linters:
- errorlint
- exptostd
- gocheckcompilerdirectives
- gocritic
- govet
- ineffassign
- intrange
@@ -35,6 +36,12 @@ linters:
errcheck:
exclude-functions:
- fmt.Fprintf
gocritic:
disabled-checks:
# Detects suspicious duplicated sub-expressions.
# Prone to false positives when used on cgo code
# https://github.com/go-critic/go-critic/issues/897#issuecomment-568892104
- dupSubExpr
perfsprint:
strconcat: false
concat-loop: false
@@ -44,24 +51,22 @@ linters:
# Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019
- -SA1019
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- -ST1000
# Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003
- -ST1003
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- -ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- -ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- -ST1022
usestdlibvars:
http-method: false
http-status-code: false
exclusions:
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- path: _test\.go
linters:
- prealloc
formatters:
enable:

View File

@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode == http.StatusUnauthorized {

View File

@@ -2,6 +2,7 @@ package api
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
@@ -39,7 +40,7 @@ func TestClientFromEnvironment(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
client, err := ClientFromEnvironment()
if err != v.err {
if !errors.Is(err, v.err) {
t.Fatalf("expected %s, got %s", v.err, err)
}
@@ -55,7 +56,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -112,20 +112,6 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -150,12 +136,6 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -194,10 +174,9 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
name string
response any
wantErr string
}{
{
name: "immediate error response",
@@ -205,8 +184,7 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
wantErr: "test error message",
},
{
name: "server error response",
@@ -214,8 +192,7 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
wantErr: "internal error",
},
{
name: "successful response",
@@ -227,26 +204,6 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -254,16 +211,11 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
@@ -290,15 +242,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -2,6 +2,7 @@ package api
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"math"
@@ -308,9 +309,9 @@ func (tp ToolProperty) ToTypeScriptType() string {
return mapToTypeScriptType(tp.Type[0])
}
var types []string
for _, t := range tp.Type {
types = append(types, mapToTypeScriptType(t))
types := make([]string, len(tp.Type))
for i, t := range tp.Type {
types[i] = mapToTypeScriptType(t)
}
return strings.Join(types, " | ")
}
@@ -783,7 +784,7 @@ func (m *Metrics) Summary() {
func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
typeOpts := reflect.TypeFor[Options]() // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
@@ -854,8 +855,7 @@ func (opts *Options) FromMap(m map[string]any) error {
}
field.Set(reflect.ValueOf(slice))
case reflect.Pointer:
var b bool
if field.Type() == reflect.TypeOf(&b) {
if field.Type() == reflect.TypeFor[*bool]() {
val, ok := val.(bool)
if !ok {
return fmt.Errorf("option %q must be of type boolean", key)
@@ -906,7 +906,7 @@ func DefaultOptions() Options {
// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
type ThinkValue struct {
// Value can be a bool or string
Value interface{}
Value any
}
// IsValid checks if the ThinkValue is valid
@@ -999,7 +999,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
return nil
}
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
return errors.New("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
}
// MarshalJSON implements json.Marshaler
@@ -1018,7 +1018,7 @@ func (d Duration) MarshalJSON() ([]byte, error) {
if d.Duration < 0 {
return []byte("-1"), nil
}
return []byte("\"" + d.Duration.String() + "\""), nil
return []byte("\"" + d.String() + "\""), nil
}
func (d *Duration) UnmarshalJSON(b []byte) (err error) {
@@ -1045,7 +1045,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
d.Duration = time.Duration(math.MaxInt64)
}
default:
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
return fmt.Errorf("unsupported type: '%s'", reflect.TypeOf(v))
}
return nil
@@ -1055,7 +1055,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
func FormatParams(params map[string][]string) (map[string]any, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
typeOpts := reflect.TypeFor[Options]() // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
@@ -1102,8 +1102,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
// TODO: only string slices are supported right now
out[key] = vals
case reflect.Pointer:
var b bool
if field.Type() == reflect.TypeOf(&b) {
if field.Type() == reflect.TypeFor[*bool]() {
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)

View File

@@ -397,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
return
}
@@ -466,8 +466,6 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage("/")
}
}

View File

@@ -24,14 +24,27 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -247,7 +260,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
[[NSWorkspace sharedWorkspace] openURL:url];
}

View File

@@ -147,9 +147,7 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage("/")
}
}

View File

@@ -22,6 +22,7 @@ import (
var ErrCancelled = errors.New("Cancelled")
// Cancelled refers to ErrCancelled.
//
// Deprecated: Use ErrCancelled instead.
var Cancelled = ErrCancelled
@@ -37,7 +38,7 @@ type MsgBuilder struct {
}
// Message initialises a MsgBuilder with the provided message.
func Message(format string, args ...interface{}) *MsgBuilder {
func Message(format string, args ...any) *MsgBuilder {
return &MsgBuilder{Msg: fmt.Sprintf(format, args...)}
}

View File

@@ -319,7 +319,7 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("timeout scanning server log for inference compute details")
return nil, errors.New("timeout scanning server log for inference compute details")
default:
}
file, err := os.Open(serverLogPath)
@@ -345,11 +345,9 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
slog.Info("Matched", "inference compute", ic)
inference = append(inference, ic)
} else {
} else if len(inference) > 0 {
// Break out on first non matching line after we start matching
if len(inference) > 0 {
return inference, nil
}
return inference, nil
}
}
time.Sleep(100 * time.Millisecond)

View File

@@ -31,7 +31,7 @@ func terminate(proc *os.Process) error {
func terminated(pid int) (bool, error) {
proc, err := os.FindProcess(pid)
if err != nil {
return false, fmt.Errorf("failed to find process: %v", err)
return false, fmt.Errorf("failed to find process: %w", err)
}
err = proc.Signal(syscall.Signal(0))
@@ -40,7 +40,7 @@ func terminated(pid int) (bool, error) {
return true, nil
}
return false, fmt.Errorf("error signaling process: %v", err)
return false, fmt.Errorf("error signaling process: %w", err)
}
return false, nil
@@ -67,8 +67,7 @@ func reapServers() error {
return nil
}
pids := strings.Split(pidsStr, "\n")
for _, pidStr := range pids {
for pidStr := range strings.SplitSeq(pidsStr, "\n") {
pidStr = strings.TrimSpace(pidStr)
if pidStr == "" {
continue

View File

@@ -5,6 +5,7 @@ package store
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
@@ -482,7 +483,8 @@ func (db *database) cleanupOrphanedData() error {
}
func duplicateColumnError(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
var sqlite3Err sqlite3.Error
if errors.As(err, &sqlite3Err) {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "duplicate column name")
}
@@ -490,7 +492,8 @@ func duplicateColumnError(err error) bool {
}
func columnNotExists(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok {
var sqlite3Err sqlite3.Error
if errors.As(err, &sqlite3Err) {
return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "no such column")
}
@@ -586,8 +589,8 @@ func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Cha
&browserState,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("chat not found")
if errors.Is(err, sql.ErrNoRows) {
return nil, errors.New("chat not found")
}
return nil, fmt.Errorf("query chat: %w", err)
}
@@ -752,7 +755,7 @@ func (db *database) updateLastMessage(chatID string, msg Message) error {
return fmt.Errorf("get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("no message found to update")
return errors.New("no message found to update")
}
_, err = tx.Exec("DELETE FROM attachments WHERE message_id = ?", messageID)

View File

@@ -282,7 +282,7 @@ func countRows(t *testing.T, db *database, table string) int {
return count
}
func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...interface{}) int {
func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...any) int {
t.Helper()
var count int
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", table, condition)
@@ -296,7 +296,7 @@ func countRowsWithCondition(t *testing.T, db *database, table, condition string,
// Test helpers for schema migration testing
// schemaMap returns both tables/columns and indexes (ignoring order)
func schemaMap(db *database) map[string]interface{} {
func schemaMap(db *database) map[string]any {
result := make(map[string]any)
result["tables"] = columnMap(db)

View File

@@ -5,6 +5,7 @@ package store
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"path/filepath"
@@ -26,7 +27,7 @@ func (i *Image) Bytes() ([]byte, error) {
// ImgBytes reads image data from the specified file path
func ImgBytes(path string) ([]byte, error) {
if path == "" {
return nil, fmt.Errorf("empty image path")
return nil, errors.New("empty image path")
}
data, err := os.ReadFile(path)

View File

@@ -4,6 +4,7 @@ package tools
import (
"context"
"errors"
"fmt"
"net/url"
"regexp"
@@ -130,7 +131,7 @@ func (b *BrowserSearch) Schema() map[string]any {
func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
query, ok := args["query"].(string)
if !ok {
return nil, "", fmt.Errorf("query parameter is required")
return nil, "", errors.New("query parameter is required")
}
topn, ok := args["topn"].(int)
@@ -150,7 +151,7 @@ func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any,
searchResponse, ok := result.(*WebSearchResponse)
if !ok {
return nil, "", fmt.Errorf("invalid search results format")
return nil, "", errors.New("invalid search results format")
}
// Build main search results page that contains all search results
@@ -383,15 +384,9 @@ func wrapLines(text string, width int) []string {
wrapped = append(wrapped, "")
} else if len(line) <= width {
wrapped = append(wrapped, line)
} else if words := strings.Fields(line); len(words) == 0 {
wrapped = append(wrapped, line)
} else {
// Word wrapping while preserving whitespace structure
words := strings.Fields(line)
if len(words) == 0 {
// Line with only whitespace
wrapped = append(wrapped, line)
continue
}
currentLine := ""
for _, word := range words {
// Check if adding this word would exceed width
@@ -536,15 +531,13 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
if err != nil {
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
}
} else {
} else if len(b.state.Data.PageStack) != 0 {
// get last page
if len(b.state.Data.PageStack) != 0 {
pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
var err error
page, err = b.getPageFromStack(pageURL)
if err != nil {
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
}
pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
var err error
page, err = b.getPageFromStack(pageURL)
if err != nil {
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
}
}
@@ -594,7 +587,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
// Try to get id as integer (link ID from current page)
if id, ok := args["id"].(float64); ok {
if page == nil {
return nil, "", fmt.Errorf("no current page to resolve link from")
return nil, "", errors.New("no current page to resolve link from")
}
idInt := int(id)
pageURL, ok := page.Links[idInt]
@@ -637,7 +630,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
// If no id provided, just display current page
if page == nil {
return nil, "", fmt.Errorf("no current page to display")
return nil, "", errors.New("no current page to display")
}
// Only add to PageStack without updating URLToPage
b.state.Data.PageStack = append(b.state.Data.PageStack, page.URL)
@@ -742,7 +735,7 @@ func (b *BrowserFind) Schema() map[string]any {
func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, string, error) {
pattern, ok := args["pattern"].(string)
if !ok {
return nil, "", fmt.Errorf("pattern parameter is required")
return nil, "", errors.New("pattern parameter is required")
}
// Get cursor parameter if provided, default to current page
@@ -756,7 +749,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
if cursor == -1 {
// Use current page
if len(b.state.Data.PageStack) == 0 {
return nil, "", fmt.Errorf("no pages to search in")
return nil, "", errors.New("no pages to search in")
}
var err error
page, err = b.getPageFromStack(b.state.Data.PageStack[len(b.state.Data.PageStack)-1])
@@ -776,7 +769,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
}
if page == nil {
return nil, "", fmt.Errorf("page not found")
return nil, "", errors.New("page not found")
}
// Create find results page

View File

@@ -5,6 +5,7 @@ package tools
import (
"context"
"encoding/json"
"errors"
"fmt"
)
@@ -87,7 +88,7 @@ func (g *BrowserCrawler) Schema() map[string]any {
func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*CrawlResponse, error) {
urlsRaw, ok := args["urls"].([]any)
if !ok {
return nil, fmt.Errorf("urls parameter is required and must be an array of strings")
return nil, errors.New("urls parameter is required and must be an array of strings")
}
urls := make([]string, 0, len(urlsRaw))
@@ -98,7 +99,7 @@ func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*Cra
}
if len(urls) == 0 {
return nil, fmt.Errorf("at least one URL is required")
return nil, errors.New("at least one URL is required")
}
return g.performWebCrawl(ctx, urls)

View File

@@ -5,6 +5,7 @@ package tools
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
@@ -84,7 +85,7 @@ func (w *BrowserWebSearch) Schema() map[string]any {
func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (any, error) {
queriesRaw, ok := args["queries"].([]any)
if !ok {
return nil, fmt.Errorf("queries parameter is required and must be an array of strings")
return nil, errors.New("queries parameter is required and must be an array of strings")
}
queries := make([]string, 0, len(queriesRaw))
@@ -95,7 +96,7 @@ func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (an
}
if len(queries) == 0 {
return nil, fmt.Errorf("at least one query is required")
return nil, errors.New("at least one query is required")
}
maxResults := 5

View File

@@ -6,6 +6,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@@ -36,7 +37,7 @@ func (w *WebFetch) Description() string {
return "Crawl and extract text content from web pages"
}
func (g *WebFetch) Schema() map[string]any {
func (w *WebFetch) Schema() map[string]any {
schemaBytes := []byte(`{
"type": "object",
"properties": {
@@ -61,11 +62,11 @@ func (w *WebFetch) Prompt() string {
func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
urlRaw, ok := args["url"]
if !ok {
return nil, "", fmt.Errorf("url parameter is required")
return nil, "", errors.New("url parameter is required")
}
urlStr, ok := urlRaw.(string)
if !ok || strings.TrimSpace(urlStr) == "" {
return nil, "", fmt.Errorf("url must be a non-empty string")
return nil, "", errors.New("url must be a non-empty string")
}
result, err := performWebFetch(ctx, urlStr)

View File

@@ -6,6 +6,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
@@ -45,7 +46,7 @@ func (w *WebSearch) Prompt() string {
return ""
}
func (g *WebSearch) Schema() map[string]any {
func (w *WebSearch) Schema() map[string]any {
schemaBytes := []byte(`{
"type": "object",
"properties": {
@@ -71,12 +72,12 @@ func (g *WebSearch) Schema() map[string]any {
func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
rawQuery, ok := args["query"]
if !ok {
return nil, "", fmt.Errorf("query parameter is required")
return nil, "", errors.New("query parameter is required")
}
queryStr, ok := rawQuery.(string)
if !ok || strings.TrimSpace(queryStr) == "" {
return nil, "", fmt.Errorf("query must be a non-empty string")
return nil, "", errors.New("query must be a non-empty string")
}
maxResults := 5

View File

@@ -19,10 +19,12 @@ import (
// Errors wrapping Found should provide additional context, e.g.
// fmt.Errorf("%w: %s", not.Found, key)
//
//nolint:staticcheck
//lint:ignore ST1012 This is a sentinel error intended to be read like not.Found.
var Found = errors.New("not found")
// Available is an error that indicates that a value is not available.
//
//nolint:staticcheck
//lint:ignore ST1012 This is a sentinel error intended to be read like not.Available.
var Available = errors.New("not available")

View File

@@ -4,6 +4,7 @@ package not
import (
"fmt"
"strings"
)
type ValidError struct {
@@ -44,12 +45,12 @@ func (b Valids) Error() string {
return ""
}
var result string
var sb strings.Builder
for i, err := range b {
if i > 0 {
result += "; "
sb.WriteString("; ")
}
result += err.Error()
sb.WriteString(err.Error())
}
return result
return sb.String()
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"path/filepath"
"slices"
"strconv"
"strings"
"unicode/utf8"
@@ -73,7 +74,7 @@ func extractPDFText(data []byte) (string, error) {
if strings.TrimSpace(text) != "" {
if textBuilder.Len() > 0 {
textBuilder.WriteString("\n\n--- Page ")
textBuilder.WriteString(fmt.Sprintf("%d", i))
textBuilder.WriteString(strconv.Itoa(i))
textBuilder.WriteString(" ---\n")
}
textBuilder.WriteString(text)

View File

@@ -194,7 +194,7 @@ func (s *Server) Handler() http.Handler {
log := s.log()
level := slog.LevelInfo
start := time.Now()
requestID := fmt.Sprintf("%d", time.Now().UnixNano())
requestID := strconv.FormatInt(time.Now().UnixNano(), 10)
defer func() {
p := recover()
@@ -204,7 +204,7 @@ func (s *Server) Handler() http.Handler {
// Handle panic with user-friendly error
if !sw.Written() {
s.handleError(sw, fmt.Errorf("internal server error"))
s.handleError(sw, errors.New("internal server error"))
}
}
@@ -382,7 +382,7 @@ func waitForServer(ctx context.Context) error {
break
}
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready")
return errors.New("timeout waiting for Ollama server to be ready")
}
time.Sleep(10 * time.Millisecond)
}
@@ -455,7 +455,7 @@ func (s *Server) checkModelUpstream(ctx context.Context, modelName string, timeo
digest := resp.Header.Get("ollama-content-digest")
if digest == "" {
return "", 0, fmt.Errorf("no digest header found")
return "", 0, errors.New("no digest header found")
}
var pushTime int64
@@ -598,12 +598,12 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
}
if req.Model == "" {
return fmt.Errorf("empty model")
return errors.New("empty model")
}
// Don't allow empty messages unless forceUpdate is true
if req.Prompt == "" && !req.ForceUpdate {
return fmt.Errorf("empty message")
return errors.New("empty message")
}
if createdChat {
@@ -942,7 +942,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
} else {
onlyStandalone := true
for _, tc := range res.Message.ToolCalls {
if !(tc.Function.Name == "web_search" || tc.Function.Name == "web_fetch") {
if tc.Function.Name != "web_search" && tc.Function.Name != "web_fetch" {
onlyStandalone = false
break
}
@@ -1194,7 +1194,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
return errors.New("chat ID is required")
}
chat, err := s.Store.Chat(cid)
@@ -1252,7 +1252,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
return errors.New("chat ID is required")
}
var req struct {
@@ -1283,7 +1283,7 @@ func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
return errors.New("chat ID is required")
}
// Check if the chat exists (no need to load attachments)
@@ -1291,7 +1291,7 @@ func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
if err != nil {
if errors.Is(err, not.Found) {
w.WriteHeader(http.StatusNotFound)
return fmt.Errorf("chat not found")
return errors.New("chat not found")
}
return fmt.Errorf("failed to get chat: %w", err)
}
@@ -1592,7 +1592,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return fmt.Errorf("method not allowed")
return errors.New("method not allowed")
}
var req struct {
@@ -1603,7 +1603,7 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
}
if req.Model == "" {
return fmt.Errorf("model is required")
return errors.New("model is required")
}
digest, pushTime, err := s.checkModelUpstream(r.Context(), req.Model, 5*time.Second)
@@ -1730,8 +1730,8 @@ func supportsWebSearchTools(model string) bool {
// buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
var msgs []api.Message
for _, m := range chat.Messages {
msgs := make([]api.Message, len(chat.Messages))
for i, m := range chat.Messages {
// Skip empty messages if present
if m.Content == "" && m.Thinking == "" && len(m.ToolCalls) == 0 && len(m.Attachments) == 0 {
continue
@@ -1789,7 +1789,7 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
s.log().Debug("unknown message role", "role", m.Role)
}
msgs = append(msgs, apiMsg)
msgs[i] = apiMsg
}
var thinkValue *api.ThinkValue

View File

@@ -198,7 +198,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(filepath.Dir(stageFilename))
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil {
return fmt.Errorf("create ollama dir %s: %v", filepath.Dir(stageFilename), err)
return fmt.Errorf("create ollama dir %s: %w", filepath.Dir(stageFilename), err)
}
}
@@ -218,7 +218,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
if err := VerifyDownload(); err != nil {
_ = os.Remove(stageFilename)
return fmt.Errorf("%s - %s", resp.Request.URL.String(), err)
return fmt.Errorf("%s - %w", resp.Request.URL.String(), err)
}
UpdateDownloaded = true
return nil

View File

@@ -92,7 +92,7 @@ func DoUpgrade(interactive bool) error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
return errors.New("failed to lookup downloads")
}
slog.Info("starting upgrade", "app", BundlePath, "update", bundle, "pid", os.Getpid(), "log", UpgradeLogFile)
@@ -107,7 +107,7 @@ func DoUpgrade(interactive bool) error {
// Verify old doesn't exist yet
if _, err := os.Stat(contentsOldName); err == nil {
slog.Error("prior upgrade failed", "backup", contentsOldName)
return fmt.Errorf("prior upgrade failed - please upgrade manually by installing the bundle")
return errors.New("prior upgrade failed - please upgrade manually by installing the bundle")
}
if err := os.MkdirAll(appBackupDir, 0o755); err != nil {
return fmt.Errorf("unable to create backup dir %s: %w", appBackupDir, err)
@@ -133,7 +133,7 @@ func DoUpgrade(interactive bool) error {
return err
}
if !chownWithAuthorization(u.Username) {
return fmt.Errorf("unable to change permissions to complete upgrade")
return errors.New("unable to change permissions to complete upgrade")
}
if err := os.Rename(BundlePath, appBackup); err != nil {
return fmt.Errorf("unable to perform upgrade - failed to stage old version: %w", err)
@@ -264,7 +264,7 @@ func DoPostUpgradeCleanup() error {
func verifyDownload() error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
return errors.New("failed to lookup downloads")
}
slog.Debug("verifying update", "bundle", bundle)
@@ -338,7 +338,7 @@ func verifyDownload() error {
}
if err := verifyExtractedBundle(filepath.Join(dir, "Ollama.app")); err != nil {
return fmt.Errorf("signature verification failed: %s", err)
return fmt.Errorf("signature verification failed: %w", err)
}
return nil
}
@@ -347,11 +347,11 @@ func verifyDownload() error {
func DoUpgradeAtStartup() error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
return errors.New("failed to lookup downloads")
}
if BundlePath == "" {
return fmt.Errorf("unable to upgrade at startup, app in development mode")
return errors.New("unable to upgrade at startup, app in development mode")
}
// [Re]verify before proceeding

View File

@@ -22,9 +22,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
// TODO - wire up the redirects to mimic real behavior
} else {
slog.Debug("unexpected request", "url", r.URL)
@@ -67,17 +65,16 @@ func TestBackgoundChecker(t *testing.T) {
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
switch r.URL.Path {
case "/update.json":
fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
// TODO - wire up the redirects to mimic real behavior
} else if r.URL.Path == "/9.9.9/"+Installer {
case "/9.9.9/" + Installer:
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
} else {
default:
slog.Debug("unexpected request", "url", r.URL)
}
}))

View File

@@ -149,7 +149,7 @@ func BenchmarkChat(fOpt flagOptions) error {
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
options := make(map[string]any)
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}

View File

@@ -442,7 +442,7 @@ func TestReadImage_FileNotFound(t *testing.T) {
func TestOptionsMapCreation(t *testing.T) {
fOpt := createTestFlagOptions()
options := make(map[string]interface{})
options := make(map[string]any)
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}

View File

@@ -11,6 +11,7 @@ import (
"fmt"
"io"
"log"
"maps"
"math"
"net"
"net/http"
@@ -203,7 +204,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if err := client.Create(cmd.Context(), req, fn); err != nil {
if strings.Contains(err.Error(), "path or Modelfile are required") {
return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
return errors.New("the ollama server must be updated to use `ollama create` with this client")
}
return err
}
@@ -990,7 +991,7 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
v = strconv.FormatBool(vData)
case string:
v = vData
case float64:
@@ -1204,9 +1205,7 @@ func (r runOptions) Copy() runOptions {
var opts map[string]any
if r.Options != nil {
opts = make(map[string]any, len(r.Options))
for k, v := range r.Options {
opts[k] = v
}
maps.Copy(opts, r.Options)
}
var think *api.ThinkValue
@@ -1330,12 +1329,12 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
cancel()
}()
var state *displayResponseState = &displayResponseState{}
state := &displayResponseState{}
var thinkingContent strings.Builder
var latest api.ChatResponse
var fullResponse strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
thinkTagOpened := false
thinkTagClosed := false
role := "assistant"
@@ -1430,7 +1429,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
@@ -1463,10 +1462,10 @@ func generate(cmd *cobra.Command, opts runOptions) error {
cancel()
}()
var state *displayResponseState = &displayResponseState{}
state := &displayResponseState{}
var thinkingContent strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
thinkTagOpened := false
thinkTagClosed := false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
@@ -1634,7 +1633,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := client.Heartbeat(cmd.Context()); err != nil {
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
if !strings.Contains(err.Error(), " refused") && !strings.Contains(err.Error(), "could not connect") {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
@@ -1952,13 +1951,13 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
}
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
out := ""
var sb strings.Builder
formatExplanation := ""
formatValues := ""
if !plainText {
formatExplanation = readline.ColorGrey + readline.ColorBold
formatValues = readline.ColorDefault
out += formatExplanation
sb.WriteString(formatExplanation)
}
for i, toolCall := range toolCalls {
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
@@ -1966,13 +1965,13 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
return ""
}
if i > 0 {
out += "\n"
sb.WriteString("\n")
}
// all tool calls are unexpected since we don't currently support registering any in the CLI
out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
fmt.Fprintf(&sb, " Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
}
if !plainText {
out += readline.ColorDefault
sb.WriteString(readline.ColorDefault)
}
return out
return sb.String()
}

View File

@@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -307,7 +308,7 @@ func TestDeleteHandler(t *testing.T) {
} else {
w.WriteHeader(http.StatusNotFound)
errPayload := `{"error":"model '%s' not found"}`
w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
fmt.Fprintf(w, errPayload, req.Name)
}
return
}
@@ -761,8 +762,8 @@ func TestGetModelfileName(t *testing.T) {
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
}
if tt.expectedErr != os.ErrNotExist {
if actualErr != tt.expectedErr {
if !errors.Is(tt.expectedErr, os.ErrNotExist) {
if !errors.Is(actualErr, tt.expectedErr) {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
} else {
@@ -924,10 +925,8 @@ func TestPushHandler(t *testing.T) {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
} else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
})
}
@@ -1014,10 +1013,8 @@ func TestListHandler(t *testing.T) {
if got := string(output); got != tt.expectedOutput {
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
} else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
})
}
@@ -1322,8 +1319,8 @@ func TestRunOptions_Copy(t *testing.T) {
// Test 2: Verify all fields are copied correctly
tests := []struct {
name string
got interface{}
want interface{}
got any
want any
}{
{"Model", copied.Model, original.Model},
{"ParentModel", copied.ParentModel, original.ParentModel},

View File

@@ -130,7 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder
var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
thinkExplicitlySet := opts.Think != nil
for {
line, err := scanner.Readline()
@@ -410,7 +410,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified for this model.")
} else {
for _, l := range strings.Split(resp.Parameters, "\n") {
for l := range strings.SplitSeq(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l)
}
}
@@ -576,9 +576,8 @@ func extractFileNames(input string) []string {
func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input)
var imgs []api.ImageData
for _, fp := range filePaths {
imgs := make([]api.ImageData, len(filePaths))
for i, fp := range filePaths {
nfp := normalizeFilePath(fp)
data, err := getImageData(nfp)
if errors.Is(err, os.ErrNotExist) {
@@ -591,7 +590,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
imgs[i] = data
}
return strings.TrimSpace(input), imgs, nil
}

View File

@@ -38,10 +38,10 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
"general.file_type": uint32(1),
"general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre,
"tokenizer.ggml.model": t.Vocabulary.Model,
"tokenizer.ggml.tokens": t.Vocabulary.Tokens,
"tokenizer.ggml.scores": t.Vocabulary.Scores,
"tokenizer.ggml.token_type": t.Vocabulary.Types,
"tokenizer.ggml.model": t.Model,
"tokenizer.ggml.tokens": t.Tokens,
"tokenizer.ggml.scores": t.Scores,
"tokenizer.ggml.token_type": t.Types,
}
if len(t.Merges) > 0 {
@@ -231,20 +231,20 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch {
case vocabSize == 0:
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Tokens))
case vocabSize > len(t.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Tokens))
for i := range vocabSize - len(t.Tokens) {
t.Tokens = append(t.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Scores = append(t.Scores, -1)
t.Types = append(t.Types, tokenTypeUserDefined)
}
case vocabSize < len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
case vocabSize < len(t.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Tokens))
p.VocabSize = uint32(len(t.Tokens))
p.TextModel.VocabSize = uint32(len(t.Tokens))
default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
slog.Debug("vocabulary", "size", len(t.Tokens))
}
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))

View File

@@ -137,7 +137,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
}
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",

View File

@@ -44,14 +44,14 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
}
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
out[i] = &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out

View File

@@ -43,18 +43,18 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
}
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne)
}
out = append(out, &ggml.Tensor{
out[i] = &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out

View File

@@ -22,8 +22,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
}
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
@@ -31,12 +31,12 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out[i] = &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out

View File

@@ -111,7 +111,7 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name = name + ".weight"
name += ".weight"
}
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{

View File

@@ -127,7 +127,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
}
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
out := make([]*ggml.Tensor, 0, len(ts)+1)
if p.RopeScaling.factors != nil {
out = append(out, &ggml.Tensor{
@@ -176,9 +176,9 @@ func (p *llamaModel) Replacements() []string {
}
func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var heads uint32

View File

@@ -30,8 +30,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
}
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
@@ -41,12 +41,12 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack)
}
out = append(out, &ggml.Tensor{
out[i] = &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
})
}
}
return out

View File

@@ -29,15 +29,6 @@ type mistral3Model struct {
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -70,13 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
@@ -104,9 +90,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
}
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
@@ -114,12 +99,12 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
}
}
out = append(out, &ggml.Tensor{
out[i] = &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
@@ -159,9 +144,9 @@ func (p *mistral3Model) Replacements() []string {
}
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var heads uint32

View File

@@ -49,20 +49,20 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
}
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
out[i] = &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen2Model) Replacements() []string {
func (q *qwen2Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",

View File

@@ -90,9 +90,9 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
return out
}
func (p *qwen25VLModel) Replacements() []string {
func (q *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
q.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",

View File

@@ -54,6 +54,6 @@ func (t torch) Clone() Tensor {
}
}
func (pt torch) WriteTo(w io.Writer) (int64, error) {
func (t torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil
}

View File

@@ -82,7 +82,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
content string
}
var ts []t
ts := make([]t, 0, len(atm))
for content, id := range atm {
ts = append(ts, t{id, content})
}

View File

@@ -65,7 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
@@ -99,9 +98,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
@@ -129,20 +125,10 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue
}
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1)
go func(i int) {
@@ -488,16 +474,3 @@ func overrideWarnings() {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -57,13 +57,8 @@ ollama ps
```
<Info>
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
**Output**:

```
NAME          ID              SIZE    PROCESSOR    UNTIL
llama3:70b    bcfb190ca3a7    42 GB   100% GPU     4 minutes from now
```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -390,4 +385,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background", then click the slider to disable.

View File

@@ -149,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -300,9 +300,9 @@ func (s Tensors) Items(prefix ...string) []*Tensor {
return items
}
func (ts Tensors) GroupLayers() map[string]Layer {
func (s Tensors) GroupLayers() map[string]Layer {
layers := make(map[string]Layer)
for _, t := range ts.items {
for _, t := range s.items {
parts := strings.Split(t.Name, ".")
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
if len(parts) > index+2 {

View File

@@ -5,6 +5,7 @@ import (
"cmp"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
@@ -225,7 +226,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
Name: name,
Kind: kind,
Offset: offset,
Shape: shape[:],
Shape: shape,
}
llm.tensors = append(llm.tensors, &tensor)
@@ -511,7 +512,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture")
if arch == "" {
return fmt.Errorf("architecture not set")
return errors.New("architecture not set")
}
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {

View File

@@ -136,8 +136,8 @@ func (t FileType) Value() uint32 {
return uint32(t)
}
func (ftype FileType) ToTensorType() TensorType {
switch ftype {
func (t FileType) ToTensorType() TensorType {
switch t {
case FileTypeF32:
return TensorTypeF32
case FileTypeF16:
@@ -177,7 +177,7 @@ func (ftype FileType) ToTensorType() TensorType {
case fileTypeMXFP4:
return TensorTypeMXFP4
default:
slog.Warn("unsupported file type", "type", ftype)
slog.Warn("unsupported file type", "type", t)
return 0 // F32
}
}

View File

@@ -11,7 +11,7 @@ type KeyValue struct {
}
func (kv KeyValue) Valid() bool {
return kv.Key != "" && kv.Value.value != nil
return kv.Key != "" && kv.value != nil
}
type Value struct {

View File

@@ -200,9 +200,7 @@ func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
before := raw[:channelIndex]
after := raw[channelIndex+len("<|channel|>"):]
// the channel name is `after` all the way up to the first (if any) whitespace character
idx := strings.IndexFunc(after, func(r rune) bool {
return unicode.IsSpace(r)
})
idx := strings.IndexFunc(after, unicode.IsSpace)
if idx == -1 {
idx = len(after)
}
@@ -319,11 +317,12 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
}
case HarmonyEventContentEmitted:
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
if h.state == harmonyMessageState_Normal {
switch h.state {
case harmonyMessageState_Normal:
contentSb.WriteString(event.Content)
} else if h.state == harmonyMessageState_Thinking {
case harmonyMessageState_Thinking:
thinkingSb.WriteString(event.Content)
} else if h.state == harmonyMessageState_ToolCalling {
case harmonyMessageState_ToolCalling:
toolContentSb.WriteString(event.Content)
}
case HarmonyEventMessageEnd:

View File

@@ -33,9 +33,6 @@ func TestVisionModels(t *testing.T) {
// Qwen 3 VL mixture of experts
model: "qwen3-vl:30b",
},
{
model: "ministral-3",
},
}
for _, v := range testCases {

View File

@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
@@ -168,7 +167,6 @@ var (
"medllama2",
"megadolphin",
"minicpm-v",
"ministral-3",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
@@ -272,7 +270,6 @@ var (
"mistral",
"qwen2.5",
"qwen2",
"ministral-3",
"mistral-nemo",
"mistral-small",
"mixtral:8x22b",

View File

@@ -263,9 +263,9 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.use_mmap = C.bool(params.UseMmap)
cparams.vocab_only = C.bool(params.VocabOnly)
var devices []C.ggml_backend_dev_t
for _, llamaID := range params.Devices {
devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
devices := make([]C.ggml_backend_dev_t, len(params.Devices))
for i, llamaID := range params.Devices {
devices[i] = C.ggml_backend_dev_get(C.size_t(llamaID))
}
if len(devices) > 0 {
devices = append(devices, C.ggml_backend_dev_t(C.NULL))

View File

@@ -250,7 +250,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
err := fmt.Errorf("error starting runner: %v %s", err, msg)
err := fmt.Errorf("error starting runner: %w %s", err, msg)
if llamaModel != nil {
llama.FreeModel(llamaModel)
}
@@ -846,14 +846,7 @@ nextOperation:
func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
devices := []ml.DeviceID{}
for _, layer := range gpuLayers {
new := true
for _, ID := range devices {
if layer.DeviceID == ID {
new = false
break
}
}
if new {
if !slices.Contains(devices, layer.DeviceID) {
devices = append(devices, layer.DeviceID)
}
}
@@ -874,7 +867,7 @@ func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.Devic
}}
}
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
err := s.verifyLayout(systemInfo, systemGPUs, memory, requireFull, gpuLayers, layers)
err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
if err != nil {
return nil, err
}
@@ -943,7 +936,7 @@ func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMe
}
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
// These sizes will only increase as we go through additional iterations and get additional information.
cpuSize := memory.InputWeights + memory.CPU.Graph
var vramSize uint64
@@ -970,8 +963,8 @@ nextLayer:
}
if requireFull {
if len(systemGPUs) > 0 && gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
slog.Info("model requires more gpu memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
return ErrLoadRequiredFull
}
@@ -989,16 +982,14 @@ nextLayer:
slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available))
}
} else {
if vramSize > systemInfo.TotalMemory {
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
s.options.NumGPU = 0
gpuLayers = ml.GPULayersList{}
}
} else if vramSize > systemInfo.TotalMemory {
// disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system
s.options.NumGPU = 0
gpuLayers = ml.GPULayersList{}
}
if len(systemGPUs) > 0 && gpuLayers.Sum() == 0 {
if gpuLayers.Sum() == 0 {
slog.Debug("insufficient VRAM to load any model layers")
}
@@ -1218,7 +1209,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
if err != nil {
return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
return ServerStatusError, fmt.Errorf("error creating GET request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
@@ -1481,7 +1472,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
return errors.New("invalid JSON schema in format")
}
req.Grammar = string(g)
}
@@ -1521,13 +1512,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
enc.SetEscapeHTML(false)
if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to marshal data: %v", err)
return fmt.Errorf("failed to marshal data: %w", err)
}
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil {
return fmt.Errorf("error creating POST request: %v", err)
return fmt.Errorf("error creating POST request: %w", err)
}
serverReq.Header.Set("Content-Type", "application/json")
@@ -1576,7 +1567,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
var c CompletionResponse
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
return fmt.Errorf("error unmarshalling llm prediction response: %w", err)
}
switch {
case strings.TrimSpace(c.Content) == lastToken:
@@ -1618,7 +1609,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("an error was encountered while running the model: %s", msg)
}
return fmt.Errorf("error reading llm response: %v", err)
return fmt.Errorf("error reading llm response: %w", err)
}
return nil
@@ -1693,7 +1684,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
defer s.llamaModelLock.Unlock()
if s.llamaModel == nil {
return nil, fmt.Errorf("no tokenizer configured")
return nil, errors.New("no tokenizer configured")
}
return s.llamaModel.Tokenize(content, false, true)
@@ -1718,15 +1709,15 @@ func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
defer s.llamaModelLock.Unlock()
if s.llamaModel == nil {
return "", fmt.Errorf("no tokenizer configured")
return "", errors.New("no tokenizer configured")
}
var resp string
var sb strings.Builder
for _, token := range tokens {
resp += s.llamaModel.TokenToPiece(token)
sb.WriteString(s.llamaModel.TokenToPiece(token))
}
return resp, nil
return sb.String(), nil
}
func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {

View File

@@ -26,11 +26,10 @@ func TestLLMServerFitGPU(t *testing.T) {
expectedErr error
}{
{
name: "No GPU",
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{},
requireFull: true, // Should not try to evict even though we can't load any layers
name: "No GPU",
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
numGPU: -1,
expected: ml.GPULayersList{},
},
{
name: "Full single GPU",
@@ -210,7 +209,7 @@ func TestLLMServerFitGPU(t *testing.T) {
}
gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
if err != tt.expectedErr {
if !errors.Is(err, tt.expectedErr) {
t.Fatalf("fitGPU returned error: %v", err)
}
if gpuLayers.Hash() != tt.expected.Hash() {

View File

@@ -84,7 +84,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil {
return 0, err
}
@@ -98,7 +98,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil {
return 0, err
}
@@ -123,7 +123,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
code := w.Status()
if code != http.StatusOK {
return w.writeError(data)
}
@@ -150,7 +150,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil {
return 0, err
}
@@ -164,7 +164,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil {
return 0, err
}
@@ -189,7 +189,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
code := w.Status()
if code != http.StatusOK {
return w.writeError(data)
}
@@ -214,7 +214,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
code := w.Status()
if code != http.StatusOK {
return w.writeError(data)
}
@@ -240,7 +240,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
code := w.Status()
if code != http.StatusOK {
return w.writeError(data)
}
@@ -265,7 +265,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
code := w.Status()
if code != http.StatusOK {
return w.writeError(data)
}

View File

@@ -68,7 +68,7 @@ func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) {
switch tc.expectType {
case "array":
if _, ok := result.Data[0].Embedding.([]interface{}); !ok {
if _, ok := result.Data[0].Embedding.([]any); !ok {
t.Errorf("expected array, got %T", result.Data[0].Embedding)
}
case "string":
@@ -210,10 +210,8 @@ func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) {
if !strings.Contains(errResp.Error.Message, "encoding_format") {
t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message)
}
} else {
if resp.Code != http.StatusOK {
t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
} else if resp.Code != http.StatusOK {
t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
})
}

View File

@@ -845,19 +845,17 @@ func TestListMiddleware(t *testing.T) {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
var want, got map[string]any
if err := json.Unmarshal([]byte(tc.resp), &want); err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
if err := json.Unmarshal(resp.Body.Bytes(), &got); err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("response does not match (-want +got):\n%s", diff)
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"math"
"slices"
@@ -92,7 +93,7 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
return backend(modelPath, params)
}
return nil, fmt.Errorf("unsupported backend")
return nil, errors.New("unsupported backend")
}
type Context interface {

View File

@@ -178,14 +178,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory.CPU.Cache = make([]uint64, blocks+1)
// create list of buffer types for each gpu
var gpuDeviceBufferTypes []deviceBufferType
gpuDeviceBufferTypes := make([]deviceBufferType, len(gpus))
requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
for i, d := range gpus {
bt := C.ggml_backend_dev_buffer_type(d)
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
gpuDeviceBufferTypes[i] = deviceBufferType{
d: d,
bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
})
}
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
@@ -354,8 +354,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t)
// create backends and buffer types used for the compute graph scheduler
var schedBackends []C.ggml_backend_t
var schedBufts []C.ggml_backend_buffer_type_t
schedBackends := make([]C.ggml_backend_t, 0, len(cpus)+len(accels)+len(gpus))
schedBufts := make([]C.ggml_backend_buffer_type_t, 0, len(cpus)+len(accels)+len(gpus))
for _, d := range append(gpus, append(accels, cpus...)...) {
b := backends[d]
bt := C.ggml_backend_get_default_buffer_type(b)

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"hash/maphash"
"io"
@@ -218,7 +219,7 @@ type BackendMemory struct {
}
func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr
attrs := make([]slog.Attr, 0, 2+len(m.GPUs))
if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
}
@@ -414,14 +415,7 @@ func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath}
for _, gpu := range l {
for _, dir := range gpu.LibraryPath {
needed := true
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
if !slices.Contains(gpuLibs, dir) {
gpuLibs = append(gpuLibs, dir)
}
}
@@ -437,15 +431,15 @@ const (
DuplicateDevice // The same physical device but different library/backend (overlapping device)
)
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID {
func (d DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if d.PCIID != b.PCIID {
return UniqueDevice
}
// If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID {
if d.PCIID == "" && d.DeviceID != b.DeviceID {
return UniqueDevice
}
if a.Library == b.Library {
if d.Library == b.Library {
return SameBackendDevice
}
return DuplicateDevice
@@ -453,8 +447,8 @@ func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
// For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1]
func (d DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := d.LibraryPath[len(d.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib {
return false
@@ -481,7 +475,7 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && (gpu.ComputeMajor != 7 || gpu.ComputeMinor != 2)) ||
gpu.Library == "ROCm" ||
gpu.Library == "Vulkan"
@@ -509,9 +503,11 @@ func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string {
// to crash at inference time and requires deeper validation before we include
// it in the supported devices list.
func (d DeviceInfo) NeedsInitValidation() bool {
// ROCm: rocblas will crash on unsupported devices.
// CUDA: verify CC is supported by the version of the library
return d.Library == "ROCm" || d.Library == "CUDA"
// At this time the only library we know needs a 2nd pass is ROCm since
// rocblas will crash on unsupported devices. We want to find those crashes
// during bootstrap discovery so we can eliminate those GPUs before the user
// tries to run inference on them
return d.Library == "ROCm"
}
// Set the init validation environment variable
@@ -547,12 +543,12 @@ func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) {
}
v, existing := env[envVar]
if existing {
v = v + ","
v += ","
}
if d.FilterID != "" {
v = v + d.FilterID
v += d.FilterID
} else {
v = v + d.ID
v += d.ID
}
env[envVar] = v
}
@@ -592,7 +588,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
for {
select {
case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout")
return nil, errors.New("failed to finish discovery before timeout")
case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil {
@@ -604,7 +600,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
if err != nil {
// slog.Warn("failed to send request", "error", err)
if runner.HasExited() {
return nil, fmt.Errorf("runner crashed")
return nil, errors.New("runner crashed")
}
continue
}
@@ -612,7 +608,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported")
return nil, errors.New("llamarunner free vram reporting not supported")
}
body, err := io.ReadAll(resp.Body)

View File

@@ -143,9 +143,9 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
r += 0x0100
case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2
r += 0x00a2
}
sb.WriteRune(r)
@@ -264,9 +264,9 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
r -= 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
r -= 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8

View File

@@ -146,7 +146,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
arch += "_embed"
}
f, ok := models[arch]
@@ -175,9 +175,10 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
tagsCopy = append(tagsCopy, parseTag(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
switch {
case tt == reflect.TypeFor[Base]():
vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
case tt == reflect.TypeFor[ml.Tensor]():
var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 {
@@ -217,9 +218,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
break
}
}
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
case tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface:
setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
case tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array:
for i := range vv.Len() {
vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {

View File

@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.attnKeyLen, m.ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
}
type MLP struct {
@@ -178,10 +178,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
m.Options.largeModelScaling = true
m.largeModelScaling = true
}
for i, layer := range m.Layers {
@@ -202,9 +202,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
return hiddenState.Scale(ctx, float64(m.finalLogitSoftcap)), nil
}
func init() {

View File

@@ -96,15 +96,15 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err
}
f32s, err := m.ImageProcessor.ProcessImage(image)
f32s, err := m.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues := ctx.Input().FromFloats(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.imageSize,
m.imageSize,
m.numChannels,
)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)

View File

@@ -111,12 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextConfig.ropeLocalBase
ropeBase := m.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextConfig.ropeGlobalBase
ropeBase = m.ropeGlobalBase
}
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
return fast.RoPE(ctx, key, shift, m.attnKeyLen, ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
@@ -166,7 +166,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
// set image embeddings
var except []int

View File

@@ -53,7 +53,7 @@ func New(c fs.Config) (model.Model, error) {
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
m.Cache = kvcache.NewCausalCache(m.Shift)
return m, nil
}
@@ -109,12 +109,12 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err
}
f32s, size, err := m.ImageProcessor.ProcessImage(image)
f32s, size, err := m.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.numChannels)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
positionsScale := m.getScale(ctx, batch.Positions)
return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -16,8 +16,6 @@ type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int
ropeScalingBeta float32
}
type TextModel struct {
@@ -36,7 +34,7 @@ type SelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
@@ -51,10 +49,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if opts.ropeOrigPosEmbeddings > 0 {
q = q.Mul(ctx, positionsScale)
}
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
@@ -82,11 +76,11 @@ type Layer struct {
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
@@ -103,7 +97,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
// image embeddings
@@ -120,36 +114,25 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, o
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
posScale := make([]float32, len(positions))
for n, pos := range positions {
interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
}
return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
},
}
}

View File

@@ -133,7 +133,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.eps)
// Prepare position IDs for 2D rope
positions := make([]int32, numPatches)

View File

@@ -54,7 +54,7 @@ func New(c fs.Config) (model.Model, error) {
encoderCache := kvcache.NewEncoderCache()
encoderCache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.Shift))
return &m, nil
}
@@ -69,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err
}
f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
f32s, ratio, err := m.ProcessImage(image)
if err != nil {
return nil, err
}

View File

@@ -223,8 +223,8 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, cros
}
func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") {
decoderLayers := make([]TextDecoderLayer, c.Uint("block_count"))
for i := range decoderLayers {
var textDecoderLayer TextDecoderLayer
if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
textDecoderLayer = &TextCrossAttentionDecoderLayer{}
@@ -232,7 +232,7 @@ func newTextModel(c fs.Config) *TextModel {
textDecoderLayer = &TextSelfAttentionDecoderLayer{}
}
decoderLayers = append(decoderLayers, textDecoderLayer)
decoderLayers[i] = textDecoderLayer
}
return &TextModel{

View File

@@ -2,6 +2,7 @@ package qwen2
import (
"cmp"
"errors"
"fmt"
"math"
"strings"
@@ -130,7 +131,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
func New(c fs.Config) (model.Model, error) {
// This model currently only supports the gpt2 tokenizer
if c.String("tokenizer.ggml.model") == "llama" {
return nil, fmt.Errorf("unsupported tokenizer: llama")
return nil, errors.New("unsupported tokenizer: llama")
}
// detect library/qwen model(s) which are incompatible
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {

View File

@@ -48,7 +48,7 @@ func New(c fs.Config) (model.Model, error) {
ImageProcessor: newImageProcessor(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
m.Cache = kvcache.NewCausalCache(m.Shift)
return m, nil
}
@@ -59,14 +59,13 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
return nil, nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
f32s, grid, err := m.ProcessImage(image)
if err != nil {
return nil, nil, err
}
// Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)

View File

@@ -228,7 +228,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.numHeads)
// Apply encoder layers
for i, layer := range m.Layers {
if slices.Contains(m.fullAttnBlocks, int32(i)) {

View File

@@ -107,7 +107,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
return nil, nil, fmt.Errorf("failed to create patches: %w", err)
}
// Return patches and grid dimensions

View File

@@ -203,7 +203,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
var _ model.Model = (*Model)(nil)

View File

@@ -111,7 +111,7 @@ func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tenso
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
return nil, nil, fmt.Errorf("failed to create patches: %w", err)
}
patchDim := p.numChannels * p.temporalPatchSize *

View File

@@ -1,136 +0,0 @@
package parsers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type ministralParserState int
const (
ministralCollectingContent = iota
ministralCollectingThinkingContent
ministralCollectingToolName
ministralCollectingToolArgs
)
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
currentTool *api.Tool
}
func (p *MinistralParser) HasToolSupport() bool {
return true
}
func (p *MinistralParser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
if !p.HasThinkingSupport() {
p.state = ministralCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = ministralCollectingContent
return
}
p.state = ministralCollectingThinkingContent
}
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.setInitialState(lastMessage)
return tools
}
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
for i := range tools {
if tools[i].Function.Name == n {
return &tools[i], nil
}
}
return nil, fmt.Errorf("tool '%s' not found", n)
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
}
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
}
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var data map[string]any
if err := json.Unmarshal([]byte(before), &data); err != nil {
// todo - throw a better error
return "", "", calls, err
}
p.state = ministralCollectingContent
call := api.ToolCall{
Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name,
Arguments: api.ToolCallFunctionArguments(data),
},
}
calls = append(calls, call)
return "", "", calls, nil
}
return "", "", calls, nil
}
return p.buffer.String(), thinking, calls, nil
}

View File

@@ -1,9 +1,6 @@
package parsers
import (
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
@@ -41,17 +38,16 @@ func ParserForName(name string) Parser {
if parser, ok := registry.constructors[name]; ok {
return parser()
}
var p Parser
switch name {
case "qwen3-coder":
p = &Qwen3CoderParser{}
parser := &Qwen3CoderParser{}
return parser
case "qwen3-vl-instruct":
p = &Qwen3VLParser{hasThinkingSupport: false}
parser := &Qwen3VLParser{hasThinkingSupport: false}
return parser
case "qwen3-vl-thinking":
p = &Qwen3VLParser{hasThinkingSupport: true}
case "ministral":
p = &MinistralParser{hasThinkingSupport: false}
parser := &Qwen3VLParser{hasThinkingSupport: true}
return parser
case "passthrough":
return &PassthroughParser{}
case "harmony":
@@ -61,7 +57,6 @@ func ParserForName(name string) Parser {
default:
return nil
}
return p
}
type PassthroughParser struct{}
@@ -81,20 +76,3 @@ func (p *PassthroughParser) HasToolSupport() bool {
// HasThinkingSupport reports whether the parser can separate thinking
// output from content; the passthrough parser never does.
func (p *PassthroughParser) HasThinkingSupport() bool {
	return false
}
func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(sb.String(), tag, 2)
if len(split) == 1 {
sb.Reset()
return split[0], ""
}
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
sb.Reset()
sb.WriteString(after)
return before, after // return events
}

View File

@@ -1,7 +1,6 @@
package parsers
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
@@ -96,164 +95,3 @@ func TestUnknownParserReturnsNil(t *testing.T) {
t.Error("expected nil for unknown parser")
}
}
// TestSplitAtTag exercises splitAtTag across tag positions (start, middle,
// end, absent), both trimAfter modes, and verifies the strings.Builder is
// left holding exactly the after-tag text.
func TestSplitAtTag(t *testing.T) {
	tests := []struct {
		name       string
		input      string
		tag        string
		trimAfter  bool
		wantBefore string
		wantAfter  string
		wantSB     string // expected content of strings.Builder after operation
	}{
		{
			name:       "basic split with trimAfter true",
			input:      "hello <!-- split --> world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "world",
			wantSB:     "world",
		},
		{
			name:       "basic split with trimAfter false",
			input:      "hello <!-- split --> world",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "hello",
			wantAfter:  " world",
			wantSB:     " world",
		},
		{
			name:       "tag at beginning with trimAfter true",
			input:      "<!-- split -->world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "",
			wantAfter:  "world",
			wantSB:     "world",
		},
		{
			name:       "tag at beginning with trimAfter false",
			input:      "<!-- split --> world",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "",
			wantAfter:  " world",
			wantSB:     " world",
		},
		{
			name:       "tag at end with trimAfter true",
			input:      "hello <!-- split -->",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "tag at end with trimAfter false",
			input:      "hello <!-- split -->",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "hello",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			// Only the first occurrence is consumed; later tags stay buffered.
			name:       "multiple tags splits at first occurrence",
			input:      "hello <!-- split --> world <!-- split --> end",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "world <!-- split --> end",
			wantSB:     "world <!-- split --> end",
		},
		{
			// Absent tag: before is returned untrimmed and the builder drains.
			name:       "tag not present",
			input:      "hello world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello world",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "empty input",
			input:      "",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "only whitespace before tag",
			input:      " \t\n<!-- split -->world",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "",
			wantAfter:  "world",
			wantSB:     "world",
		},
		{
			name:       "only whitespace after tag with trimAfter true",
			input:      "hello<!-- split --> \t\n",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: "hello",
			wantAfter:  "",
			wantSB:     "",
		},
		{
			name:       "only whitespace after tag with trimAfter false",
			input:      "hello<!-- split --> \t\n",
			tag:        "<!-- split -->",
			trimAfter:  false,
			wantBefore: "hello",
			wantAfter:  " \t\n",
			wantSB:     " \t\n",
		},
		{
			// Trailing whitespace on before is always trimmed; leading
			// whitespace on before and trailing whitespace on after survive.
			name:       "complex whitespace trimming",
			input:      " hello \t\n <!-- split --> \n\t world ",
			tag:        "<!-- split -->",
			trimAfter:  true,
			wantBefore: " hello",
			wantAfter:  "world ",
			wantSB:     "world ",
		},
		{
			name:       "tag with special characters",
			input:      "text <tag attr=\"value\"> more text",
			tag:        "<tag attr=\"value\">",
			trimAfter:  true,
			wantBefore: "text",
			wantAfter:  "more text",
			wantSB:     "more text",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sb := &strings.Builder{}
			sb.WriteString(tt.input)

			before, after := splitAtTag(sb, tt.tag, tt.trimAfter)

			// Check return values
			if before != tt.wantBefore {
				t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
			}
			if after != tt.wantAfter {
				t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
			}

			// Check strings.Builder state
			if sb.String() != tt.wantSB {
				t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
			}
		})
	}
}

View File

@@ -70,6 +70,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
@@ -80,7 +81,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
calls = append(calls, toolCall)
toolCalls = append(toolCalls, toolCall)
case qwenEventThinkingContent:
thinkingSb.WriteString(event.content)
case qwenEventContent:
@@ -90,7 +91,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
}
}
return contentSb.String(), thinkingSb.String(), calls, nil
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Qwen3VLParser) parseEvents() []qwenEvent {
@@ -112,6 +113,19 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
return all
}
// splitAtTag splits the parser's buffer at the first occurrence of tag.
// It returns the text before the tag (right-trimmed of whitespace) and the
// text after it (left-trimmed when trimAfter is set), leaving only the
// after-text in the buffer.
//
// Precondition: the buffer must contain tag — callers guard with
// strings.Contains before calling; if the tag is absent, indexing split[1]
// panics.
func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
	split := strings.SplitN(p.buffer.String(), tag, 2)
	before := split[0]
	before = strings.TrimRightFunc(before, unicode.IsSpace)
	after := split[1]
	if trimAfter {
		after = strings.TrimLeftFunc(after, unicode.IsSpace)
	}
	p.buffer.Reset()
	p.buffer.WriteString(after)
	return before, after // caller turns before into a completed event payload
}
func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
@@ -130,7 +144,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
case CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
// events = emitContentBeforeTag(p, events, toolOpenTag)
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
before, _ := splitAtTag(p, toolOpenTag, false)
if len(before) > 0 {
events = append(events, qwenEventContent{content: before})
}
@@ -181,7 +195,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
}
case CollectingThinkingContent:
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwenEventThinkingContent{content: thinking})
}

View File

@@ -98,7 +98,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
if multiStepTool && message.Role == "user" {
// Check if content starts with <tool_response> and ends with </tool_response>
content := r.renderContent(message)
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
if !strings.HasPrefix(content, "<tool_response>") || !strings.HasSuffix(content, "</tool_response>") {
multiStepTool = false
lastQueryIndex = i
}

View File

@@ -205,12 +205,12 @@ func (q queue) Less(i, j int) bool {
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) {
func (q *queue) Push(x any) {
item := x.(*candidate)
*q = append(*q, item)
}
func (q *queue) Pop() interface{} {
func (q *queue) Pop() any {
old := *q
n := len(old)
item := old[n-1]
@@ -231,7 +231,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err)
return "", fmt.Errorf("failed to parse hex byte: %w", err)
}
if err := sb.WriteByte(byte(byteVal)); err != nil {

View File

@@ -232,9 +232,9 @@ func NewError(code int, message string) ErrorResponse {
// ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
return Usage{
PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
@@ -326,9 +326,9 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
return Usage{
PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
}
}
@@ -377,20 +377,19 @@ func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
// ToListCompletion converts an api.ListResponse to ListCompletion
func ToListCompletion(r api.ListResponse) ListCompletion {
var data []Model
for _, m := range r.Models {
data = append(data, Model{
Id: m.Name,
Object: "model",
Created: m.ModifiedAt.Unix(),
OwnedBy: model.ParseName(m.Name).Namespace,
})
}
return ListCompletion{
Object: "list",
Data: data,
c := ListCompletion{Object: "list"}
if len(r.Models) > 0 {
c.Data = make([]Model, len(r.Models))
for i, m := range r.Models {
c.Data[i] = Model{
Id: m.Name,
Object: "model",
Created: m.ModifiedAt.Unix(),
OwnedBy: model.ParseName(m.Name).Namespace,
}
}
}
return c
}
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
@@ -487,19 +486,14 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
types := []string{"jpeg", "jpg", "png", "webp"}
valid := false
// support blank mime type to match api/chat taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
valid = true
}
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
url, valid := strings.CutPrefix(url, "data:;base64,")
if !valid {
for _, t := range []string{"jpeg", "jpg", "png", "webp"} {
prefix := "data:image/" + t + ";base64,"
url, valid = strings.CutPrefix(url, prefix)
if valid {
break
}
}
}

View File

@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"maps"
"net/http"
"os"
"os/user"
@@ -78,9 +79,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
if req.Files == nil {
req.Files = digestMap
} else {
for k, v := range digestMap {
req.Files[k] = v
}
maps.Copy(req.Files, digestMap)
}
case "adapter":
path, err := expandPath(c.Args, relativeDir)
@@ -371,7 +370,7 @@ func (e *ParserError) Error() string {
func ParseFile(r io.Reader) (*Modelfile, error) {
var cmd Command
var curr state
var currLine int = 1
currLine := 1
var b bytes.Buffer
var role string

View File

@@ -326,17 +326,11 @@ MESSAGE system`,
return
}
switch tt.err.(type) {
case *ParserError:
var pErr *ParserError
if errors.As(err, &pErr) {
// got the correct type of error
return
}
}
if errors.Is(err, tt.err) {
return
} else if pErr := (*ParserError)(nil); errors.As(err, &pErr) {
// got the correct type of error
return
}
t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err)
@@ -1089,7 +1083,7 @@ func TestFilesForModel(t *testing.T) {
if err == nil {
t.Error("Expected error, but got none")
}
if tt.expectErrType != nil && err != tt.expectErrType {
if tt.expectErrType != nil && !errors.Is(err, tt.expectErrType) {
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
}
return

View File

@@ -3,6 +3,7 @@ package readline
import (
"fmt"
"os"
"strings"
"github.com/emirpasic/gods/v2/lists/arraylist"
"github.com/mattn/go-runewidth"
@@ -297,7 +298,7 @@ func (b *Buffer) drawRemaining() {
remaining := (remainingText[len(currLine):])
var totalLines int
var displayLength int
var lineLength int = currLineSpace
lineLength := currLineSpace
for _, c := range remaining {
if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
@@ -515,13 +516,13 @@ func (b *Buffer) StringN(n int) string {
}
func (b *Buffer) StringNM(n, m int) string {
var s string
var sb strings.Builder
if m == 0 {
m = b.Buf.Size()
}
for cnt := n; cnt < m; cnt++ {
c, _ := b.Buf.Get(cnt)
s += string(c)
sb.WriteRune(c)
}
return s
return sb.String()
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"strings"
)
type Prompt struct {
@@ -124,18 +125,19 @@ func (i *Instance) Readline() (string, error) {
case KeyRight:
buf.MoveRight()
case CharBracketedPaste:
var code string
var code strings.Builder
for range 3 {
r, err = i.Terminal.Read()
if err != nil {
return "", io.EOF
}
code += string(r)
code.WriteRune(r)
}
if code == CharBracketedPasteStart {
switch code.String() {
case CharBracketedPasteStart:
i.Pasting = true
} else if code == CharBracketedPasteEnd {
case CharBracketedPasteEnd:
i.Pasting = false
}
case KeyDel:

View File

@@ -459,10 +459,7 @@ func TestLogprobsWithStopSequences(t *testing.T) {
origLogprobsLen := len(logprobs)
numTokensRemoved := origLen - newLen
newLogprobsLen := origLogprobsLen - numTokensRemoved
if newLogprobsLen < 0 {
newLogprobsLen = 0
}
newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
logprobs = logprobs[:newLogprobsLen]
// Verify responses were truncated correctly

View File

@@ -39,21 +39,15 @@ func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
}
var result []string
result := make([]string, 0, len(pieces))
tokenTruncated := false
start := 0
for _, length := range lengths {
for _, piece := range pieces {
if start >= len(joined) {
break
}
end := start + length
end := start + len(piece)
if end > len(joined) {
end = len(joined)
tokenTruncated = true

View File

@@ -61,7 +61,7 @@ func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []by
return nil, nil
}
if len(data) <= 0 {
if len(data) == 0 {
return nil, errors.New("received zero length image")
}

View File

@@ -1,6 +1,7 @@
package llamarunner
import (
"errors"
"reflect"
"testing"
@@ -18,7 +19,7 @@ func TestImageCache(t *testing.T) {
// Empty cache
result, err := cache.findImage(0x5adb61d31933a946)
if err != errImageNotFound {
if !errors.Is(err, errImageNotFound) {
t.Errorf("found result in empty cache: result %v, err %v", result, err)
}

View File

@@ -577,10 +577,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
if seq.logprobs {
origLogprobsLen := len(seq.pendingLogprobs)
numTokensRemoved := origLen - newLen
newLogprobsLen := origLogprobsLen - numTokensRemoved
if newLogprobsLen < 0 {
newLogprobsLen = 0
}
newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
}
@@ -998,7 +995,6 @@ func Execute(args []string) error {
log.Println("Server listening on", addr)
if err := httpServer.Serve(listener); err != nil {
log.Fatal("server error:", err)
return err
}

View File

@@ -2,7 +2,6 @@ package ollamarunner
import (
"errors"
"fmt"
"slices"
"testing"
"time"
@@ -511,7 +510,7 @@ type mockCache struct {
// Implement only the methods needed for the test
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
if m.shouldFail {
return fmt.Errorf("mock cache removal error")
return errors.New("mock cache removal error")
}
return nil
}

View File

@@ -801,10 +801,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq.logprobs {
origLogprobsLen := len(seq.pendingLogprobs)
numTokensRemoved := origLen - newLen
newLogprobsLen := origLogprobsLen - numTokensRemoved
if newLogprobsLen < 0 {
newLogprobsLen = 0
}
newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
}
@@ -1242,7 +1239,7 @@ func (s *Server) loadModel() {
s.progress = progress
})
if err != nil {
panic(fmt.Errorf("failed to load model: %v", err))
panic(fmt.Errorf("failed to load model: %w", err))
}
s.status = llm.ServerStatusReady
@@ -1432,7 +1429,6 @@ func Execute(args []string) error {
log.Println("Server listening on", addr)
if err := httpServer.Serve(listener); err != nil {
log.Fatal("server error:", err)
return err
}

View File

@@ -30,7 +30,7 @@ func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
ts[i].value /= temp
}
}

View File

@@ -33,7 +33,7 @@ func (r registryChallenge) URL() (*url.URL, error) {
values := redirectURL.Query()
values.Add("service", r.Service)
for _, s := range strings.Split(r.Scope, " ") {
for s := range strings.SplitSeq(r.Scope, " ") {
values.Add("scope", s)
}
@@ -57,7 +57,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
}
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
data := fmt.Appendf(nil, "%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))
headers := make(http.Header)
signature, err := auth.Sign(ctx, data)
@@ -75,7 +75,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
body, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("%d: %v", response.StatusCode, err)
return "", fmt.Errorf("%d: %w", response.StatusCode, err)
}
if response.StatusCode >= http.StatusBadRequest {

View File

@@ -386,7 +386,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
}
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
return nil, fmt.Errorf("%w: %w: %s", errFilePath, err, fp)
}
blobPath, err := GetBlobsPath(digest)
@@ -456,15 +456,15 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
return l.KV(), nil
}
}
return ggml.KV{}, fmt.Errorf("no base model was found")
return ggml.KV{}, errors.New("no base model was found")
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer
for _, layer := range baseLayers {
layers := make([]Layer, len(baseLayers))
for i, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
if quantType != "" && layer.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
want, err := ggml.ParseFileType(quantType)
if err != nil {
return err
@@ -480,13 +480,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
}
}
config.ModelFormat = cmp.Or(config.ModelFormat, layer.GGML.Name())
config.ModelFormat = cmp.Or(config.ModelFormat, layer.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
}
layers = append(layers, layer.Layer)
layers[i] = layer.Layer
}
if r.Template != "" {
@@ -678,10 +678,10 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
func setTemplate(layers []Layer, t string) ([]Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
}
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
}
blob := strings.NewReader(t)

Some files were not shown because too many files have changed in this diff Show More