Compare commits


8 Commits

Author SHA1 Message Date
Michael Yang b59053a883 presets 2025-11-19 17:26:18 -08:00
Michael Yang bb93e5afe7 errorlint 2025-11-19 17:26:18 -08:00
Michael Yang 4d24d8a77d gocritic 2025-11-19 17:26:18 -08:00
Michael Yang f01c83ed6d fmt 2025-11-19 17:26:18 -08:00
Michael Yang d3228355be staticcheck 2025-11-19 17:26:18 -08:00
Michael Yang 78a75a30d8 prealloc 2025-11-19 17:26:18 -08:00
Michael Yang 974ae8ef84 perfsprint 2025-11-19 17:26:17 -08:00
Michael Yang efd9f5e67e modernize 2025-11-19 17:26:17 -08:00
108 changed files with 484 additions and 547 deletions

View File

@ -36,6 +36,12 @@ linters:
errcheck: errcheck:
exclude-functions: exclude-functions:
- fmt.Fprintf - fmt.Fprintf
gocritic:
disabled-checks:
# Detects suspicious duplicated sub-expressions.
# Prone to false positives when used on cgo code
# https://github.com/go-critic/go-critic/issues/897#issuecomment-568892104
- dupSubExpr
perfsprint: perfsprint:
strconcat: false strconcat: false
concat-loop: false concat-loop: false
@ -45,24 +51,22 @@ linters:
# Using a deprecated function, variable, constant or field. # Using a deprecated function, variable, constant or field.
# https://staticcheck.dev/docs/checks/#SA1019 # https://staticcheck.dev/docs/checks/#SA1019
- -SA1019 - -SA1019
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- -ST1000
# Poorly chosen identifier. # Poorly chosen identifier.
# https://staticcheck.dev/docs/checks/#ST1003 # https://staticcheck.dev/docs/checks/#ST1003
- -ST1003 - -ST1003
# The documentation of an exported function should start with the function's name.
# https://staticcheck.dev/docs/checks/#ST1020
- -ST1020
# The documentation of an exported type should start with type's name.
# https://staticcheck.dev/docs/checks/#ST1021
- -ST1021
# The documentation of an exported variable or constant should start with variable's name.
# https://staticcheck.dev/docs/checks/#ST1022
- -ST1022
usestdlibvars: usestdlibvars:
http-method: false http-method: false
http-status-code: false http-status-code: false
exclusions:
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
rules:
- path: _test\.go
linters:
- prealloc
formatters: formatters:
enable: enable:
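
Note on the config hunk above: it disables gocritic's dupSubExpr check (prone to false positives on cgo code), replaces the per-check ST10xx staticcheck excludes with the shared exclusions presets, and stops prealloc from firing in _test.go files. A minimal, illustrative Go sketch (hypothetical names) of the kind of expression dupSubExpr reports:

    package main

    import "fmt"

    type point struct{ x, y int }

    func main() {
        p, q := point{1, 2}, point{1, 3}

        // dupSubExpr flags suspicious duplicated sub-expressions such as a
        // value compared against itself, usually a copy/paste slip.
        if p.x == p.x { // reported: likely meant p.x == q.x
            fmt.Println("always true")
        }
        if p.x == q.x {
            fmt.Println("intended comparison")
        }
    }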

View File

@ -2,6 +2,7 @@ package api
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -39,7 +40,7 @@ func TestClientFromEnvironment(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value) t.Setenv("OLLAMA_HOST", v.value)
client, err := ClientFromEnvironment() client, err := ClientFromEnvironment()
if err != v.err { if !errors.Is(err, v.err) {
t.Fatalf("expected %s, got %s", v.err, err) t.Fatalf("expected %s, got %s", v.err, err)
} }
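
The change above is the errorlint pattern repeated throughout these commits: comparing errors with == misses wrapped errors, while errors.Is walks the wrap chain. A minimal sketch with a hypothetical sentinel:

    package main

    import (
        "errors"
        "fmt"
    )

    var ErrNotFound = errors.New("not found")

    func lookup() error {
        // Wrapping adds context but makes the value compare unequal to the sentinel.
        return fmt.Errorf("lookup failed: %w", ErrNotFound)
    }

    func main() {
        err := lookup()
        fmt.Println(err == ErrNotFound)          // false: different value once wrapped
        fmt.Println(errors.Is(err, ErrNotFound)) // true: errors.Is unwraps the chain
    }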

View File

@ -2,6 +2,7 @@ package api
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
@ -308,9 +309,9 @@ func (tp ToolProperty) ToTypeScriptType() string {
return mapToTypeScriptType(tp.Type[0]) return mapToTypeScriptType(tp.Type[0])
} }
var types []string types := make([]string, len(tp.Type))
for _, t := range tp.Type { for i, t := range tp.Type {
types = append(types, mapToTypeScriptType(t)) types[i] = mapToTypeScriptType(t)
} }
return strings.Join(types, " | ") return strings.Join(types, " | ")
} }
@ -783,7 +784,7 @@ func (m *Metrics) Summary() {
func (opts *Options) FromMap(m map[string]any) error { func (opts *Options) FromMap(m map[string]any) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct typeOpts := reflect.TypeFor[Options]() // types of the fields in the options struct
// build map of json struct tags to their types // build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField) jsonOpts := make(map[string]reflect.StructField)
@ -854,8 +855,7 @@ func (opts *Options) FromMap(m map[string]any) error {
} }
field.Set(reflect.ValueOf(slice)) field.Set(reflect.ValueOf(slice))
case reflect.Pointer: case reflect.Pointer:
var b bool if field.Type() == reflect.TypeFor[*bool]() {
if field.Type() == reflect.TypeOf(&b) {
val, ok := val.(bool) val, ok := val.(bool)
if !ok { if !ok {
return fmt.Errorf("option %q must be of type boolean", key) return fmt.Errorf("option %q must be of type boolean", key)
@ -906,7 +906,7 @@ func DefaultOptions() Options {
// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low") // ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
type ThinkValue struct { type ThinkValue struct {
// Value can be a bool or string // Value can be a bool or string
Value interface{} Value any
} }
// IsValid checks if the ThinkValue is valid // IsValid checks if the ThinkValue is valid
@ -999,7 +999,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
return nil return nil
} }
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)") return errors.New("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
} }
// MarshalJSON implements json.Marshaler // MarshalJSON implements json.Marshaler
@ -1018,7 +1018,7 @@ func (d Duration) MarshalJSON() ([]byte, error) {
if d.Duration < 0 { if d.Duration < 0 {
return []byte("-1"), nil return []byte("-1"), nil
} }
return []byte("\"" + d.Duration.String() + "\""), nil return []byte("\"" + d.String() + "\""), nil
} }
func (d *Duration) UnmarshalJSON(b []byte) (err error) { func (d *Duration) UnmarshalJSON(b []byte) (err error) {
@ -1045,7 +1045,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
d.Duration = time.Duration(math.MaxInt64) d.Duration = time.Duration(math.MaxInt64)
} }
default: default:
return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v)) return fmt.Errorf("unsupported type: '%s'", reflect.TypeOf(v))
} }
return nil return nil
@ -1055,7 +1055,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
func FormatParams(params map[string][]string) (map[string]any, error) { func FormatParams(params map[string][]string) (map[string]any, error) {
opts := Options{} opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct typeOpts := reflect.TypeFor[Options]() // types of the fields in the options struct
// build map of json struct tags to their types // build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField) jsonOpts := make(map[string]reflect.StructField)
@ -1102,8 +1102,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
// TODO: only string slices are supported right now // TODO: only string slices are supported right now
out[key] = vals out[key] = vals
case reflect.Pointer: case reflect.Pointer:
var b bool if field.Type() == reflect.TypeFor[*bool]() {
if field.Type() == reflect.TypeOf(&b) {
boolVal, err := strconv.ParseBool(vals[0]) boolVal, err := strconv.ParseBool(vals[0])
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals) return nil, fmt.Errorf("invalid bool value %s", vals)

View File

@ -22,6 +22,7 @@ import (
var ErrCancelled = errors.New("Cancelled") var ErrCancelled = errors.New("Cancelled")
// Cancelled refers to ErrCancelled. // Cancelled refers to ErrCancelled.
//
// Deprecated: Use ErrCancelled instead. // Deprecated: Use ErrCancelled instead.
var Cancelled = ErrCancelled var Cancelled = ErrCancelled
@ -37,7 +38,7 @@ type MsgBuilder struct {
} }
// Message initialises a MsgBuilder with the provided message. // Message initialises a MsgBuilder with the provided message.
func Message(format string, args ...interface{}) *MsgBuilder { func Message(format string, args ...any) *MsgBuilder {
return &MsgBuilder{Msg: fmt.Sprintf(format, args...)} return &MsgBuilder{Msg: fmt.Sprintf(format, args...)}
} }

View File

@ -319,7 +319,7 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, fmt.Errorf("timeout scanning server log for inference compute details") return nil, errors.New("timeout scanning server log for inference compute details")
default: default:
} }
file, err := os.Open(serverLogPath) file, err := os.Open(serverLogPath)
@ -345,11 +345,9 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
slog.Info("Matched", "inference compute", ic) slog.Info("Matched", "inference compute", ic)
inference = append(inference, ic) inference = append(inference, ic)
} else { } else if len(inference) > 0 {
// Break out on first non matching line after we start matching // Break out on first non matching line after we start matching
if len(inference) > 0 { return inference, nil
return inference, nil
}
} }
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)

View File

@ -31,7 +31,7 @@ func terminate(proc *os.Process) error {
func terminated(pid int) (bool, error) { func terminated(pid int) (bool, error) {
proc, err := os.FindProcess(pid) proc, err := os.FindProcess(pid)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to find process: %v", err) return false, fmt.Errorf("failed to find process: %w", err)
} }
err = proc.Signal(syscall.Signal(0)) err = proc.Signal(syscall.Signal(0))
@ -40,7 +40,7 @@ func terminated(pid int) (bool, error) {
return true, nil return true, nil
} }
return false, fmt.Errorf("error signaling process: %v", err) return false, fmt.Errorf("error signaling process: %w", err)
} }
return false, nil return false, nil
@ -67,8 +67,7 @@ func reapServers() error {
return nil return nil
} }
pids := strings.Split(pidsStr, "\n") for pidStr := range strings.SplitSeq(pidsStr, "\n") {
for _, pidStr := range pids {
pidStr = strings.TrimSpace(pidStr) pidStr = strings.TrimSpace(pidStr)
if pidStr == "" { if pidStr == "" {
continue continue
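
The loop above switches from strings.Split to strings.SplitSeq (Go 1.24), which yields substrings lazily instead of allocating the whole slice when the result is only ranged over once. A small sketch with made-up input:

    package main

    import (
        "fmt"
        "strings"
    )

    func main() {
        pids := "101\n102\n\n103\n"

        // strings.Split materialises the full []string up front.
        fmt.Println(len(strings.Split(pids, "\n"))) // 5, including empty entries

        // strings.SplitSeq yields the same substrings one at a time.
        for p := range strings.SplitSeq(pids, "\n") {
            p = strings.TrimSpace(p)
            if p == "" {
                continue
            }
            fmt.Println(p) // 101, 102, 103
        }
    }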

View File

@ -5,6 +5,7 @@ package store
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@ -482,7 +483,8 @@ func (db *database) cleanupOrphanedData() error {
} }
func duplicateColumnError(err error) bool { func duplicateColumnError(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok { var sqlite3Err sqlite3.Error
if errors.As(err, &sqlite3Err) {
return sqlite3Err.Code == sqlite3.ErrError && return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "duplicate column name") strings.Contains(sqlite3Err.Error(), "duplicate column name")
} }
@ -490,7 +492,8 @@ func duplicateColumnError(err error) bool {
} }
func columnNotExists(err error) bool { func columnNotExists(err error) bool {
if sqlite3Err, ok := err.(sqlite3.Error); ok { var sqlite3Err sqlite3.Error
if errors.As(err, &sqlite3Err) {
return sqlite3Err.Code == sqlite3.ErrError && return sqlite3Err.Code == sqlite3.ErrError &&
strings.Contains(sqlite3Err.Error(), "no such column") strings.Contains(sqlite3Err.Error(), "no such column")
} }
@ -586,8 +589,8 @@ func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Cha
&browserState, &browserState,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if errors.Is(err, sql.ErrNoRows) {
return nil, fmt.Errorf("chat not found") return nil, errors.New("chat not found")
} }
return nil, fmt.Errorf("query chat: %w", err) return nil, fmt.Errorf("query chat: %w", err)
} }
@ -752,7 +755,7 @@ func (db *database) updateLastMessage(chatID string, msg Message) error {
return fmt.Errorf("get rows affected: %w", err) return fmt.Errorf("get rows affected: %w", err)
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return fmt.Errorf("no message found to update") return errors.New("no message found to update")
} }
_, err = tx.Exec("DELETE FROM attachments WHERE message_id = ?", messageID) _, err = tx.Exec("DELETE FROM attachments WHERE message_id = ?", messageID)
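
The duplicateColumnError and columnNotExists changes above replace direct type assertions with errors.As, which also matches driver errors that have been wrapped. A self-contained sketch using a hypothetical QueryError in place of sqlite3.Error:

    package main

    import (
        "errors"
        "fmt"
    )

    // QueryError stands in for a driver error type such as sqlite3.Error.
    type QueryError struct{ Code int }

    func (e QueryError) Error() string { return fmt.Sprintf("query error %d", e.Code) }

    func migrate() error {
        // Callers commonly wrap driver errors before returning them.
        return fmt.Errorf("migrate schema: %w", QueryError{Code: 1})
    }

    func main() {
        err := migrate()

        // A plain type assertion only sees the outer wrapper.
        _, ok := err.(QueryError)
        fmt.Println(ok) // false

        // errors.As walks the wrap chain and fills in the target.
        var qe QueryError
        if errors.As(err, &qe) {
            fmt.Println("code:", qe.Code) // code: 1
        }
    }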

View File

@ -282,7 +282,7 @@ func countRows(t *testing.T, db *database, table string) int {
return count return count
} }
func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...interface{}) int { func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...any) int {
t.Helper() t.Helper()
var count int var count int
query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", table, condition) query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", table, condition)
@ -296,7 +296,7 @@ func countRowsWithCondition(t *testing.T, db *database, table, condition string,
// Test helpers for schema migration testing // Test helpers for schema migration testing
// schemaMap returns both tables/columns and indexes (ignoring order) // schemaMap returns both tables/columns and indexes (ignoring order)
func schemaMap(db *database) map[string]interface{} { func schemaMap(db *database) map[string]any {
result := make(map[string]any) result := make(map[string]any)
result["tables"] = columnMap(db) result["tables"] = columnMap(db)

View File

@ -5,6 +5,7 @@ package store
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -26,7 +27,7 @@ func (i *Image) Bytes() ([]byte, error) {
// ImgBytes reads image data from the specified file path // ImgBytes reads image data from the specified file path
func ImgBytes(path string) ([]byte, error) { func ImgBytes(path string) ([]byte, error) {
if path == "" { if path == "" {
return nil, fmt.Errorf("empty image path") return nil, errors.New("empty image path")
} }
data, err := os.ReadFile(path) data, err := os.ReadFile(path)

View File

@ -4,6 +4,7 @@ package tools
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/url" "net/url"
"regexp" "regexp"
@ -130,7 +131,7 @@ func (b *BrowserSearch) Schema() map[string]any {
func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) { func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
query, ok := args["query"].(string) query, ok := args["query"].(string)
if !ok { if !ok {
return nil, "", fmt.Errorf("query parameter is required") return nil, "", errors.New("query parameter is required")
} }
topn, ok := args["topn"].(int) topn, ok := args["topn"].(int)
@ -150,7 +151,7 @@ func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any,
searchResponse, ok := result.(*WebSearchResponse) searchResponse, ok := result.(*WebSearchResponse)
if !ok { if !ok {
return nil, "", fmt.Errorf("invalid search results format") return nil, "", errors.New("invalid search results format")
} }
// Build main search results page that contains all search results // Build main search results page that contains all search results
@ -383,15 +384,9 @@ func wrapLines(text string, width int) []string {
wrapped = append(wrapped, "") wrapped = append(wrapped, "")
} else if len(line) <= width { } else if len(line) <= width {
wrapped = append(wrapped, line) wrapped = append(wrapped, line)
} else if words := strings.Fields(line); len(words) == 0 {
wrapped = append(wrapped, line)
} else { } else {
// Word wrapping while preserving whitespace structure
words := strings.Fields(line)
if len(words) == 0 {
// Line with only whitespace
wrapped = append(wrapped, line)
continue
}
currentLine := "" currentLine := ""
for _, word := range words { for _, word := range words {
// Check if adding this word would exceed width // Check if adding this word would exceed width
@ -536,15 +531,13 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
if err != nil { if err != nil {
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err) return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
} }
} else { } else if len(b.state.Data.PageStack) != 0 {
// get last page // get last page
if len(b.state.Data.PageStack) != 0 { pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1] var err error
var err error page, err = b.getPageFromStack(pageURL)
page, err = b.getPageFromStack(pageURL) if err != nil {
if err != nil { return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
}
} }
} }
@ -594,7 +587,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
// Try to get id as integer (link ID from current page) // Try to get id as integer (link ID from current page)
if id, ok := args["id"].(float64); ok { if id, ok := args["id"].(float64); ok {
if page == nil { if page == nil {
return nil, "", fmt.Errorf("no current page to resolve link from") return nil, "", errors.New("no current page to resolve link from")
} }
idInt := int(id) idInt := int(id)
pageURL, ok := page.Links[idInt] pageURL, ok := page.Links[idInt]
@ -637,7 +630,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
// If no id provided, just display current page // If no id provided, just display current page
if page == nil { if page == nil {
return nil, "", fmt.Errorf("no current page to display") return nil, "", errors.New("no current page to display")
} }
// Only add to PageStack without updating URLToPage // Only add to PageStack without updating URLToPage
b.state.Data.PageStack = append(b.state.Data.PageStack, page.URL) b.state.Data.PageStack = append(b.state.Data.PageStack, page.URL)
@ -742,7 +735,7 @@ func (b *BrowserFind) Schema() map[string]any {
func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, string, error) { func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, string, error) {
pattern, ok := args["pattern"].(string) pattern, ok := args["pattern"].(string)
if !ok { if !ok {
return nil, "", fmt.Errorf("pattern parameter is required") return nil, "", errors.New("pattern parameter is required")
} }
// Get cursor parameter if provided, default to current page // Get cursor parameter if provided, default to current page
@ -756,7 +749,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
if cursor == -1 { if cursor == -1 {
// Use current page // Use current page
if len(b.state.Data.PageStack) == 0 { if len(b.state.Data.PageStack) == 0 {
return nil, "", fmt.Errorf("no pages to search in") return nil, "", errors.New("no pages to search in")
} }
var err error var err error
page, err = b.getPageFromStack(b.state.Data.PageStack[len(b.state.Data.PageStack)-1]) page, err = b.getPageFromStack(b.state.Data.PageStack[len(b.state.Data.PageStack)-1])
@ -776,7 +769,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
} }
if page == nil { if page == nil {
return nil, "", fmt.Errorf("page not found") return nil, "", errors.New("page not found")
} }
// Create find results page // Create find results page
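
Most error changes in this file follow the perfsprint rule applied across these commits: fmt.Errorf with no format verbs becomes errors.New, while calls that actually format keep fmt.Errorf (with %w for wrapping). An illustrative sketch:

    package main

    import (
        "errors"
        "fmt"
    )

    func requireQuery(args map[string]any) (string, error) {
        q, ok := args["query"].(string)
        if !ok {
            // Constant message, no verbs: errors.New skips the printf machinery.
            return "", errors.New("query parameter is required")
        }
        if q == "" {
            // Formatting is still needed here, so fmt.Errorf stays.
            return "", fmt.Errorf("query %q must be non-empty", q)
        }
        return q, nil
    }

    func main() {
        _, err := requireQuery(map[string]any{})
        fmt.Println(err)
    }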

View File

@ -5,6 +5,7 @@ package tools
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
) )
@ -87,7 +88,7 @@ func (g *BrowserCrawler) Schema() map[string]any {
func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*CrawlResponse, error) { func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*CrawlResponse, error) {
urlsRaw, ok := args["urls"].([]any) urlsRaw, ok := args["urls"].([]any)
if !ok { if !ok {
return nil, fmt.Errorf("urls parameter is required and must be an array of strings") return nil, errors.New("urls parameter is required and must be an array of strings")
} }
urls := make([]string, 0, len(urlsRaw)) urls := make([]string, 0, len(urlsRaw))
@ -98,7 +99,7 @@ func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*Cra
} }
if len(urls) == 0 { if len(urls) == 0 {
return nil, fmt.Errorf("at least one URL is required") return nil, errors.New("at least one URL is required")
} }
return g.performWebCrawl(ctx, urls) return g.performWebCrawl(ctx, urls)

View File

@ -5,6 +5,7 @@ package tools
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"time" "time"
@ -84,7 +85,7 @@ func (w *BrowserWebSearch) Schema() map[string]any {
func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (any, error) { func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (any, error) {
queriesRaw, ok := args["queries"].([]any) queriesRaw, ok := args["queries"].([]any)
if !ok { if !ok {
return nil, fmt.Errorf("queries parameter is required and must be an array of strings") return nil, errors.New("queries parameter is required and must be an array of strings")
} }
queries := make([]string, 0, len(queriesRaw)) queries := make([]string, 0, len(queriesRaw))
@ -95,7 +96,7 @@ func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (an
} }
if len(queries) == 0 { if len(queries) == 0 {
return nil, fmt.Errorf("at least one query is required") return nil, errors.New("at least one query is required")
} }
maxResults := 5 maxResults := 5

View File

@ -6,6 +6,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -36,7 +37,7 @@ func (w *WebFetch) Description() string {
return "Crawl and extract text content from web pages" return "Crawl and extract text content from web pages"
} }
func (g *WebFetch) Schema() map[string]any { func (w *WebFetch) Schema() map[string]any {
schemaBytes := []byte(`{ schemaBytes := []byte(`{
"type": "object", "type": "object",
"properties": { "properties": {
@ -61,11 +62,11 @@ func (w *WebFetch) Prompt() string {
func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, string, error) { func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
urlRaw, ok := args["url"] urlRaw, ok := args["url"]
if !ok { if !ok {
return nil, "", fmt.Errorf("url parameter is required") return nil, "", errors.New("url parameter is required")
} }
urlStr, ok := urlRaw.(string) urlStr, ok := urlRaw.(string)
if !ok || strings.TrimSpace(urlStr) == "" { if !ok || strings.TrimSpace(urlStr) == "" {
return nil, "", fmt.Errorf("url must be a non-empty string") return nil, "", errors.New("url must be a non-empty string")
} }
result, err := performWebFetch(ctx, urlStr) result, err := performWebFetch(ctx, urlStr)

View File

@ -6,6 +6,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -45,7 +46,7 @@ func (w *WebSearch) Prompt() string {
return "" return ""
} }
func (g *WebSearch) Schema() map[string]any { func (w *WebSearch) Schema() map[string]any {
schemaBytes := []byte(`{ schemaBytes := []byte(`{
"type": "object", "type": "object",
"properties": { "properties": {
@ -71,12 +72,12 @@ func (g *WebSearch) Schema() map[string]any {
func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) { func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
rawQuery, ok := args["query"] rawQuery, ok := args["query"]
if !ok { if !ok {
return nil, "", fmt.Errorf("query parameter is required") return nil, "", errors.New("query parameter is required")
} }
queryStr, ok := rawQuery.(string) queryStr, ok := rawQuery.(string)
if !ok || strings.TrimSpace(queryStr) == "" { if !ok || strings.TrimSpace(queryStr) == "" {
return nil, "", fmt.Errorf("query must be a non-empty string") return nil, "", errors.New("query must be a non-empty string")
} }
maxResults := 5 maxResults := 5

View File

@ -19,10 +19,12 @@ import (
// Errors wrapping Found should provide additional context, e.g. // Errors wrapping Found should provide additional context, e.g.
// fmt.Errorf("%w: %s", not.Found, key) // fmt.Errorf("%w: %s", not.Found, key)
// //
//nolint:staticcheck
//lint:ignore ST1012 This is a sentinel error intended to be read like not.Found. //lint:ignore ST1012 This is a sentinel error intended to be read like not.Found.
var Found = errors.New("not found") var Found = errors.New("not found")
// Available is an error that indicates that a value is not available. // Available is an error that indicates that a value is not available.
// //
//nolint:staticcheck
//lint:ignore ST1012 This is a sentinel error intended to be read like not.Available. //lint:ignore ST1012 This is a sentinel error intended to be read like not.Available.
var Available = errors.New("not available") var Available = errors.New("not available")

View File

@ -4,6 +4,7 @@ package not
import ( import (
"fmt" "fmt"
"strings"
) )
type ValidError struct { type ValidError struct {
@ -44,12 +45,12 @@ func (b Valids) Error() string {
return "" return ""
} }
var result string var sb strings.Builder
for i, err := range b { for i, err := range b {
if i > 0 { if i > 0 {
result += "; " sb.WriteString("; ")
} }
result += err.Error() sb.WriteString(err.Error())
} }
return result return sb.String()
} }
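
The Valids.Error rewrite above swaps repeated string concatenation for strings.Builder, which appends into one growing buffer instead of copying the whole result on every +=. A small sketch of the same join:

    package main

    import (
        "fmt"
        "strings"
    )

    func joinErrors(msgs []string) string {
        var sb strings.Builder
        for i, m := range msgs {
            if i > 0 {
                sb.WriteString("; ")
            }
            sb.WriteString(m)
        }
        return sb.String()
    }

    func main() {
        msgs := []string{"name is required", "model is invalid"}
        fmt.Println(joinErrors(msgs))
        // strings.Join expresses the same thing when the input is already []string.
        fmt.Println(strings.Join(msgs, "; "))
    }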

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"slices" "slices"
"strconv"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@ -73,7 +74,7 @@ func extractPDFText(data []byte) (string, error) {
if strings.TrimSpace(text) != "" { if strings.TrimSpace(text) != "" {
if textBuilder.Len() > 0 { if textBuilder.Len() > 0 {
textBuilder.WriteString("\n\n--- Page ") textBuilder.WriteString("\n\n--- Page ")
textBuilder.WriteString(fmt.Sprintf("%d", i)) textBuilder.WriteString(strconv.Itoa(i))
textBuilder.WriteString(" ---\n") textBuilder.WriteString(" ---\n")
} }
textBuilder.WriteString(text) textBuilder.WriteString(text)
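
fmt.Sprintf("%d", i) above becomes strconv.Itoa(i), another perfsprint substitution: a direct integer-to-string conversion instead of the general formatting path. Sketch:

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    func main() {
        page := 3

        var b strings.Builder
        b.WriteString("--- Page ")
        b.WriteString(strconv.Itoa(page)) // was fmt.Sprintf("%d", page)
        b.WriteString(" ---")
        fmt.Println(b.String()) // --- Page 3 ---

        // The int64 counterpart, as used for the request ID later in this diff.
        fmt.Println(strconv.FormatInt(1700000000000, 10))
    }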

View File

@ -194,7 +194,7 @@ func (s *Server) Handler() http.Handler {
log := s.log() log := s.log()
level := slog.LevelInfo level := slog.LevelInfo
start := time.Now() start := time.Now()
requestID := fmt.Sprintf("%d", time.Now().UnixNano()) requestID := strconv.FormatInt(time.Now().UnixNano(), 10)
defer func() { defer func() {
p := recover() p := recover()
@ -204,7 +204,7 @@ func (s *Server) Handler() http.Handler {
// Handle panic with user-friendly error // Handle panic with user-friendly error
if !sw.Written() { if !sw.Written() {
s.handleError(sw, fmt.Errorf("internal server error")) s.handleError(sw, errors.New("internal server error"))
} }
} }
@ -382,7 +382,7 @@ func waitForServer(ctx context.Context) error {
break break
} }
if time.Now().After(timeout) { if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready") return errors.New("timeout waiting for Ollama server to be ready")
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
} }
@ -455,7 +455,7 @@ func (s *Server) checkModelUpstream(ctx context.Context, modelName string, timeo
digest := resp.Header.Get("ollama-content-digest") digest := resp.Header.Get("ollama-content-digest")
if digest == "" { if digest == "" {
return "", 0, fmt.Errorf("no digest header found") return "", 0, errors.New("no digest header found")
} }
var pushTime int64 var pushTime int64
@ -598,12 +598,12 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
} }
if req.Model == "" { if req.Model == "" {
return fmt.Errorf("empty model") return errors.New("empty model")
} }
// Don't allow empty messages unless forceUpdate is true // Don't allow empty messages unless forceUpdate is true
if req.Prompt == "" && !req.ForceUpdate { if req.Prompt == "" && !req.ForceUpdate {
return fmt.Errorf("empty message") return errors.New("empty message")
} }
if createdChat { if createdChat {
@ -942,7 +942,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
} else { } else {
onlyStandalone := true onlyStandalone := true
for _, tc := range res.Message.ToolCalls { for _, tc := range res.Message.ToolCalls {
if !(tc.Function.Name == "web_search" || tc.Function.Name == "web_fetch") { if tc.Function.Name != "web_search" && tc.Function.Name != "web_fetch" {
onlyStandalone = false onlyStandalone = false
break break
} }
@ -1194,7 +1194,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id") cid := r.PathValue("id")
if cid == "" { if cid == "" {
return fmt.Errorf("chat ID is required") return errors.New("chat ID is required")
} }
chat, err := s.Store.Chat(cid) chat, err := s.Store.Chat(cid)
@ -1252,7 +1252,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error { func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id") cid := r.PathValue("id")
if cid == "" { if cid == "" {
return fmt.Errorf("chat ID is required") return errors.New("chat ID is required")
} }
var req struct { var req struct {
@ -1283,7 +1283,7 @@ func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error { func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id") cid := r.PathValue("id")
if cid == "" { if cid == "" {
return fmt.Errorf("chat ID is required") return errors.New("chat ID is required")
} }
// Check if the chat exists (no need to load attachments) // Check if the chat exists (no need to load attachments)
@ -1291,7 +1291,7 @@ func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
if err != nil { if err != nil {
if errors.Is(err, not.Found) { if errors.Is(err, not.Found) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
return fmt.Errorf("chat not found") return errors.New("chat not found")
} }
return fmt.Errorf("failed to get chat: %w", err) return fmt.Errorf("failed to get chat: %w", err)
} }
@ -1592,7 +1592,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err
func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error { func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" { if r.Method != "POST" {
return fmt.Errorf("method not allowed") return errors.New("method not allowed")
} }
var req struct { var req struct {
@ -1603,7 +1603,7 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
} }
if req.Model == "" { if req.Model == "" {
return fmt.Errorf("model is required") return errors.New("model is required")
} }
digest, pushTime, err := s.checkModelUpstream(r.Context(), req.Model, 5*time.Second) digest, pushTime, err := s.checkModelUpstream(r.Context(), req.Model, 5*time.Second)
@ -1730,8 +1730,8 @@ func supportsWebSearchTools(model string) bool {
// buildChatRequest converts store.Chat to api.ChatRequest // buildChatRequest converts store.Chat to api.ChatRequest
func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) { func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
var msgs []api.Message msgs := make([]api.Message, len(chat.Messages))
for _, m := range chat.Messages { for i, m := range chat.Messages {
// Skip empty messages if present // Skip empty messages if present
if m.Content == "" && m.Thinking == "" && len(m.ToolCalls) == 0 && len(m.Attachments) == 0 { if m.Content == "" && m.Thinking == "" && len(m.ToolCalls) == 0 && len(m.Attachments) == 0 {
continue continue
@ -1789,7 +1789,7 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
s.log().Debug("unknown message role", "role", m.Role) s.log().Debug("unknown message role", "role", m.Role)
} }
msgs = append(msgs, apiMsg) msgs[i] = apiMsg
} }
var thinkValue *api.ThinkValue var thinkValue *api.ThinkValue

View File

@ -198,7 +198,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(filepath.Dir(stageFilename)) _, err = os.Stat(filepath.Dir(stageFilename))
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil { if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil {
return fmt.Errorf("create ollama dir %s: %v", filepath.Dir(stageFilename), err) return fmt.Errorf("create ollama dir %s: %w", filepath.Dir(stageFilename), err)
} }
} }
@ -218,7 +218,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
if err := VerifyDownload(); err != nil { if err := VerifyDownload(); err != nil {
_ = os.Remove(stageFilename) _ = os.Remove(stageFilename)
return fmt.Errorf("%s - %s", resp.Request.URL.String(), err) return fmt.Errorf("%s - %w", resp.Request.URL.String(), err)
} }
UpdateDownloaded = true UpdateDownloaded = true
return nil return nil
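
The %v-to-%w changes above are errorlint's wrapping rule: %w keeps the underlying error in the chain so callers can still inspect it with errors.Is or errors.As, whereas %v flattens it to text. A minimal sketch:

    package main

    import (
        "errors"
        "fmt"
        "io/fs"
        "os"
    )

    func ensureDirExists(dir string) error {
        if _, err := os.Stat(dir); err != nil {
            // %w preserves the original error; %v would have discarded it.
            return fmt.Errorf("create ollama dir %s: %w", dir, err)
        }
        return nil
    }

    func main() {
        err := ensureDirExists("/definitely/does/not/exist")
        fmt.Println(errors.Is(err, fs.ErrNotExist)) // true, because of %w
    }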

View File

@ -92,7 +92,7 @@ func DoUpgrade(interactive bool) error {
bundle := getStagedUpdate() bundle := getStagedUpdate()
if bundle == "" { if bundle == "" {
return fmt.Errorf("failed to lookup downloads") return errors.New("failed to lookup downloads")
} }
slog.Info("starting upgrade", "app", BundlePath, "update", bundle, "pid", os.Getpid(), "log", UpgradeLogFile) slog.Info("starting upgrade", "app", BundlePath, "update", bundle, "pid", os.Getpid(), "log", UpgradeLogFile)
@ -107,7 +107,7 @@ func DoUpgrade(interactive bool) error {
// Verify old doesn't exist yet // Verify old doesn't exist yet
if _, err := os.Stat(contentsOldName); err == nil { if _, err := os.Stat(contentsOldName); err == nil {
slog.Error("prior upgrade failed", "backup", contentsOldName) slog.Error("prior upgrade failed", "backup", contentsOldName)
return fmt.Errorf("prior upgrade failed - please upgrade manually by installing the bundle") return errors.New("prior upgrade failed - please upgrade manually by installing the bundle")
} }
if err := os.MkdirAll(appBackupDir, 0o755); err != nil { if err := os.MkdirAll(appBackupDir, 0o755); err != nil {
return fmt.Errorf("unable to create backup dir %s: %w", appBackupDir, err) return fmt.Errorf("unable to create backup dir %s: %w", appBackupDir, err)
@ -133,7 +133,7 @@ func DoUpgrade(interactive bool) error {
return err return err
} }
if !chownWithAuthorization(u.Username) { if !chownWithAuthorization(u.Username) {
return fmt.Errorf("unable to change permissions to complete upgrade") return errors.New("unable to change permissions to complete upgrade")
} }
if err := os.Rename(BundlePath, appBackup); err != nil { if err := os.Rename(BundlePath, appBackup); err != nil {
return fmt.Errorf("unable to perform upgrade - failed to stage old version: %w", err) return fmt.Errorf("unable to perform upgrade - failed to stage old version: %w", err)
@ -264,7 +264,7 @@ func DoPostUpgradeCleanup() error {
func verifyDownload() error { func verifyDownload() error {
bundle := getStagedUpdate() bundle := getStagedUpdate()
if bundle == "" { if bundle == "" {
return fmt.Errorf("failed to lookup downloads") return errors.New("failed to lookup downloads")
} }
slog.Debug("verifying update", "bundle", bundle) slog.Debug("verifying update", "bundle", bundle)
@ -338,7 +338,7 @@ func verifyDownload() error {
} }
if err := verifyExtractedBundle(filepath.Join(dir, "Ollama.app")); err != nil { if err := verifyExtractedBundle(filepath.Join(dir, "Ollama.app")); err != nil {
return fmt.Errorf("signature verification failed: %s", err) return fmt.Errorf("signature verification failed: %w", err)
} }
return nil return nil
} }
@ -347,11 +347,11 @@ func verifyDownload() error {
func DoUpgradeAtStartup() error { func DoUpgradeAtStartup() error {
bundle := getStagedUpdate() bundle := getStagedUpdate()
if bundle == "" { if bundle == "" {
return fmt.Errorf("failed to lookup downloads") return errors.New("failed to lookup downloads")
} }
if BundlePath == "" { if BundlePath == "" {
return fmt.Errorf("unable to upgrade at startup, app in development mode") return errors.New("unable to upgrade at startup, app in development mode")
} }
// [Re]verify before proceeding // [Re]verify before proceeding

View File

@ -22,9 +22,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
var server *httptest.Server var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" { if r.URL.Path == "/update.json" {
w.Write([]byte( fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
// TODO - wire up the redirects to mimic real behavior // TODO - wire up the redirects to mimic real behavior
} else { } else {
slog.Debug("unexpected request", "url", r.URL) slog.Debug("unexpected request", "url", r.URL)
@ -67,17 +65,16 @@ func TestBackgoundChecker(t *testing.T) {
var server *httptest.Server var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" { switch r.URL.Path {
w.Write([]byte( case "/update.json":
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`, fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
server.URL+"/9.9.9/"+Installer)))
// TODO - wire up the redirects to mimic real behavior // TODO - wire up the redirects to mimic real behavior
} else if r.URL.Path == "/9.9.9/"+Installer { case "/9.9.9/" + Installer:
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
zw := zip.NewWriter(buf) zw := zip.NewWriter(buf)
zw.Close() zw.Close()
io.Copy(w, buf) io.Copy(w, buf)
} else { default:
slog.Debug("unexpected request", "url", r.URL) slog.Debug("unexpected request", "url", r.URL)
} }
})) }))

View File

@ -149,7 +149,7 @@ func BenchmarkChat(fOpt flagOptions) error {
for _, model := range models { for _, model := range models {
for range *fOpt.epochs { for range *fOpt.epochs {
options := make(map[string]interface{}) options := make(map[string]any)
if *fOpt.maxTokens > 0 { if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens options["num_predict"] = *fOpt.maxTokens
} }
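
map[string]interface{} becomes map[string]any here and in several other files; since Go 1.18, any is a predeclared alias for interface{}, and the modernize pass rewrites the long form. The two spellings are interchangeable:

    package main

    import "fmt"

    // Identical signature to: func describe(v interface{}) string
    func describe(v any) string {
        return fmt.Sprintf("%T: %v", v, v)
    }

    func main() {
        opts := map[string]any{"num_predict": 128, "temperature": 0.7}
        for k, v := range opts {
            fmt.Println(k, "=", describe(v))
        }
    }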

View File

@ -442,7 +442,7 @@ func TestReadImage_FileNotFound(t *testing.T) {
func TestOptionsMapCreation(t *testing.T) { func TestOptionsMapCreation(t *testing.T) {
fOpt := createTestFlagOptions() fOpt := createTestFlagOptions()
options := make(map[string]interface{}) options := make(map[string]any)
if *fOpt.maxTokens > 0 { if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens options["num_predict"] = *fOpt.maxTokens
} }

View File

@ -11,6 +11,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"maps"
"math" "math"
"net" "net"
"net/http" "net/http"
@ -203,7 +204,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
if err := client.Create(cmd.Context(), req, fn); err != nil { if err := client.Create(cmd.Context(), req, fn); err != nil {
if strings.Contains(err.Error(), "path or Modelfile are required") { if strings.Contains(err.Error(), "path or Modelfile are required") {
return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client") return errors.New("the ollama server must be updated to use `ollama create` with this client")
} }
return err return err
} }
@ -990,7 +991,7 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
var v string var v string
switch vData := resp.ModelInfo[k].(type) { switch vData := resp.ModelInfo[k].(type) {
case bool: case bool:
v = fmt.Sprintf("%t", vData) v = strconv.FormatBool(vData)
case string: case string:
v = vData v = vData
case float64: case float64:
@ -1204,9 +1205,7 @@ func (r runOptions) Copy() runOptions {
var opts map[string]any var opts map[string]any
if r.Options != nil { if r.Options != nil {
opts = make(map[string]any, len(r.Options)) opts = make(map[string]any, len(r.Options))
for k, v := range r.Options { maps.Copy(opts, r.Options)
opts[k] = v
}
} }
var think *api.ThinkValue var think *api.ThinkValue
@ -1330,12 +1329,12 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
cancel() cancel()
}() }()
var state *displayResponseState = &displayResponseState{} state := &displayResponseState{}
var thinkingContent strings.Builder var thinkingContent strings.Builder
var latest api.ChatResponse var latest api.ChatResponse
var fullResponse strings.Builder var fullResponse strings.Builder
var thinkTagOpened bool = false thinkTagOpened := false
var thinkTagClosed bool = false thinkTagClosed := false
role := "assistant" role := "assistant"
@ -1463,10 +1462,10 @@ func generate(cmd *cobra.Command, opts runOptions) error {
cancel() cancel()
}() }()
var state *displayResponseState = &displayResponseState{} state := &displayResponseState{}
var thinkingContent strings.Builder var thinkingContent strings.Builder
var thinkTagOpened bool = false thinkTagOpened := false
var thinkTagClosed bool = false thinkTagClosed := false
plainText := !term.IsTerminal(int(os.Stdout.Fd())) plainText := !term.IsTerminal(int(os.Stdout.Fd()))
@ -1634,7 +1633,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err return err
} }
if err := client.Heartbeat(cmd.Context()); err != nil { if err := client.Heartbeat(cmd.Context()); err != nil {
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) { if !strings.Contains(err.Error(), " refused") && !strings.Contains(err.Error(), "could not connect") {
return err return err
} }
if err := startApp(cmd.Context(), client); err != nil { if err := startApp(cmd.Context(), client); err != nil {
@ -1952,13 +1951,13 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
} }
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string { func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
out := "" var sb strings.Builder
formatExplanation := "" formatExplanation := ""
formatValues := "" formatValues := ""
if !plainText { if !plainText {
formatExplanation = readline.ColorGrey + readline.ColorBold formatExplanation = readline.ColorGrey + readline.ColorBold
formatValues = readline.ColorDefault formatValues = readline.ColorDefault
out += formatExplanation sb.WriteString(formatExplanation)
} }
for i, toolCall := range toolCalls { for i, toolCall := range toolCalls {
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments) argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
@ -1966,13 +1965,13 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
return "" return ""
} }
if i > 0 { if i > 0 {
out += "\n" sb.WriteString("\n")
} }
// all tool calls are unexpected since we don't currently support registering any in the CLI // all tool calls are unexpected since we don't currently support registering any in the CLI
out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation) fmt.Fprintf(&sb, " Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
} }
if !plainText { if !plainText {
out += readline.ColorDefault sb.WriteString(readline.ColorDefault)
} }
return out return sb.String()
} }
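
Two patterns from the cmd changes above: maps.Copy (Go 1.21) replaces the manual key-by-key copy loop, and strings.Builder replaces string += in renderToolCalls. A sketch of the maps.Copy half:

    package main

    import (
        "fmt"
        "maps"
    )

    func main() {
        src := map[string]any{"temperature": 0.7, "num_predict": 128}

        // Manual copy, as before the change.
        dst1 := make(map[string]any, len(src))
        for k, v := range src {
            dst1[k] = v
        }

        // maps.Copy does the same in one call.
        dst2 := make(map[string]any, len(src))
        maps.Copy(dst2, src)

        fmt.Println(len(dst1), len(dst2)) // 2 2
    }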

View File

@ -3,6 +3,7 @@ package cmd
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -307,7 +308,7 @@ func TestDeleteHandler(t *testing.T) {
} else { } else {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
errPayload := `{"error":"model '%s' not found"}` errPayload := `{"error":"model '%s' not found"}`
w.Write([]byte(fmt.Sprintf(errPayload, req.Name))) fmt.Fprintf(w, errPayload, req.Name)
} }
return return
} }
@ -761,8 +762,8 @@ func TestGetModelfileName(t *testing.T) {
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename) t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
} }
if tt.expectedErr != os.ErrNotExist { if !errors.Is(tt.expectedErr, os.ErrNotExist) {
if actualErr != tt.expectedErr { if !errors.Is(actualErr, tt.expectedErr) {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
} }
} else { } else {
@ -924,10 +925,8 @@ func TestPushHandler(t *testing.T) {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got) t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
} }
} }
} else { } else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) { t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
} }
}) })
} }
@ -1014,10 +1013,8 @@ func TestListHandler(t *testing.T) {
if got := string(output); got != tt.expectedOutput { if got := string(output); got != tt.expectedOutput {
t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got) t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
} }
} else { } else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) { t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
} }
}) })
} }
@ -1322,8 +1319,8 @@ func TestRunOptions_Copy(t *testing.T) {
// Test 2: Verify all fields are copied correctly // Test 2: Verify all fields are copied correctly
tests := []struct { tests := []struct {
name string name string
got interface{} got any
want interface{} want any
}{ }{
{"Model", copied.Model, original.Model}, {"Model", copied.Model, original.Model},
{"ParentModel", copied.ParentModel, original.ParentModel}, {"ParentModel", copied.ParentModel, original.ParentModel},

View File

@ -130,7 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil thinkExplicitlySet := opts.Think != nil
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@ -410,7 +410,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if resp.Parameters == "" { if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified for this model.") fmt.Println(" No additional parameters were specified for this model.")
} else { } else {
for _, l := range strings.Split(resp.Parameters, "\n") { for l := range strings.SplitSeq(resp.Parameters, "\n") {
fmt.Printf(" %s\n", l) fmt.Printf(" %s\n", l)
} }
} }
@ -576,9 +576,8 @@ func extractFileNames(input string) []string {
func extractFileData(input string) (string, []api.ImageData, error) { func extractFileData(input string) (string, []api.ImageData, error) {
filePaths := extractFileNames(input) filePaths := extractFileNames(input)
var imgs []api.ImageData imgs := make([]api.ImageData, len(filePaths))
for i, fp := range filePaths {
for _, fp := range filePaths {
nfp := normalizeFilePath(fp) nfp := normalizeFilePath(fp)
data, err := getImageData(nfp) data, err := getImageData(nfp)
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
@ -591,7 +590,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
input = strings.ReplaceAll(input, "'"+nfp+"'", "") input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "") input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data) imgs[i] = data
} }
return strings.TrimSpace(input), imgs, nil return strings.TrimSpace(input), imgs, nil
} }

View File

@ -38,10 +38,10 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
"general.file_type": uint32(1), "general.file_type": uint32(1),
"general.quantization_version": uint32(2), "general.quantization_version": uint32(2),
"tokenizer.ggml.pre": t.Pre, "tokenizer.ggml.pre": t.Pre,
"tokenizer.ggml.model": t.Vocabulary.Model, "tokenizer.ggml.model": t.Model,
"tokenizer.ggml.tokens": t.Vocabulary.Tokens, "tokenizer.ggml.tokens": t.Tokens,
"tokenizer.ggml.scores": t.Vocabulary.Scores, "tokenizer.ggml.scores": t.Scores,
"tokenizer.ggml.token_type": t.Vocabulary.Types, "tokenizer.ggml.token_type": t.Types,
} }
if len(t.Merges) > 0 { if len(t.Merges) > 0 {
@ -231,20 +231,20 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch { switch {
case vocabSize == 0: case vocabSize == 0:
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Tokens))
case vocabSize > len(t.Vocabulary.Tokens): case vocabSize > len(t.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) { for i := range vocabSize - len(t.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) t.Tokens = append(t.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) t.Scores = append(t.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) t.Types = append(t.Types, tokenTypeUserDefined)
} }
case vocabSize < len(t.Vocabulary.Tokens): case vocabSize < len(t.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Tokens))
p.VocabSize = uint32(len(t.Vocabulary.Tokens)) p.VocabSize = uint32(len(t.Tokens))
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens)) p.TextModel.VocabSize = uint32(len(t.Tokens))
default: default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary", "size", len(t.Tokens))
} }
ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...)) ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))

View File

@ -137,7 +137,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
} }
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts { for _, t := range ts {
if slices.Contains([]string{ if slices.Contains([]string{
"embeddings.position_ids", "embeddings.position_ids",

View File

@ -44,14 +44,14 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
} }
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, len(ts))
for _, t := range ts { for i, t := range ts {
out = append(out, &ggml.Tensor{ out[i] = &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
}) }
} }
return out return out
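
This and the following converter files apply the prealloc rule: when the element count is known, allocate the slice once with make and assign by index instead of growing it with append. A generic sketch (hypothetical tensor type):

    package main

    import "fmt"

    type tensor struct{ name string }

    func main() {
        names := []string{"blk.0.attn_q.weight", "blk.0.attn_k.weight"}

        // Before: zero-capacity slice, append may reallocate as it grows.
        var out []*tensor
        for _, n := range names {
            out = append(out, &tensor{name: n})
        }

        // After: length known up front, allocate once and index.
        out2 := make([]*tensor, len(names))
        for i, n := range names {
            out2[i] = &tensor{name: n}
        }

        fmt.Println(len(out), len(out2)) // 2 2
    }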

View File

@ -43,18 +43,18 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
} }
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, len(ts))
for _, t := range ts { for i, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") { if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne) t.SetRepacker(p.addOne)
} }
out = append(out, &ggml.Tensor{ out[i] = &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
}) }
} }
return out return out

View File

@ -22,8 +22,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
} }
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor { func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, len(ts))
for _, t := range ts { for i, t := range ts {
shape := t.Shape() shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) || if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) { (strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
@ -31,12 +31,12 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
out = append(out, &ggml.Tensor{ out[i] = &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
}) }
} }
return out return out

View File

@ -111,7 +111,7 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s { for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape() dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") { if !strings.HasSuffix(name, ".weight") {
name = name + ".weight" name += ".weight"
} }
if strings.Contains(name, "ffn_down_exps") { if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{

View File

@ -127,7 +127,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
} }
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, 0, len(ts)+1)
if p.RopeScaling.factors != nil { if p.RopeScaling.factors != nil {
out = append(out, &ggml.Tensor{ out = append(out, &ggml.Tensor{
@ -176,9 +176,9 @@ func (p *llamaModel) Replacements() []string {
} }
func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) { func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int dims := make([]int, len(shape))
for _, dim := range shape { for i, dim := range shape {
dims = append(dims, int(dim)) dims[i] = int(dim)
} }
var heads uint32 var heads uint32

View File

@ -30,8 +30,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
} }
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor { func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, len(ts))
for _, t := range ts { for i, t := range ts {
shape := t.Shape() shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) || if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
(strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) { (strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {
@ -41,12 +41,12 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
out = append(out, &ggml.Tensor{ out[i] = &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: shape, Shape: shape,
WriterTo: t, WriterTo: t,
}) }
} }
return out return out

View File

@ -90,9 +90,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
} }
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor { func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, len(ts))
for i, t := range ts {
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") { if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") || if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") { strings.HasSuffix(t.Name(), ".attn_k.weight") {
@ -100,12 +99,12 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
} }
} }
out = append(out, &ggml.Tensor{ out[i] = &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
}) }
} }
return out return out
@ -145,9 +144,9 @@ func (p *mistral3Model) Replacements() []string {
} }
func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) { func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int dims := make([]int, len(shape))
for _, dim := range shape { for i, dim := range shape {
dims = append(dims, int(dim)) dims[i] = int(dim)
} }
var heads uint32 var heads uint32

View File

@ -49,20 +49,20 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
} }
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor { func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor out := make([]*ggml.Tensor, len(ts))
for _, t := range ts { for i, t := range ts {
out = append(out, &ggml.Tensor{ out[i] = &ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
WriterTo: t, WriterTo: t,
}) }
} }
return out return out
} }
func (p *qwen2Model) Replacements() []string { func (q *qwen2Model) Replacements() []string {
return []string{ return []string{
"lm_head", "output", "lm_head", "output",
"model.embed_tokens", "token_embd", "model.embed_tokens", "token_embd",

View File

@ -90,9 +90,9 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
return out return out
} }
func (p *qwen25VLModel) Replacements() []string { func (q *qwen25VLModel) Replacements() []string {
return append( return append(
p.qwen2Model.Replacements(), q.qwen2Model.Replacements(),
"visual", "v", "visual", "v",
"blocks", "blk", "blocks", "blk",
"attn.proj", "attn_out", "attn.proj", "attn_out",

View File

@ -54,6 +54,6 @@ func (t torch) Clone() Tensor {
} }
} }
func (pt torch) WriteTo(w io.Writer) (int64, error) { func (t torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil return 0, nil
} }

View File

@ -82,7 +82,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
content string content string
} }
var ts []t ts := make([]t, 0, len(atm))
for content, id := range atm { for content, id := range atm {
ts = append(ts, t{id, content}) ts = append(ts, t{id, content})
} }

View File

@ -300,9 +300,9 @@ func (s Tensors) Items(prefix ...string) []*Tensor {
return items return items
} }
func (ts Tensors) GroupLayers() map[string]Layer { func (s Tensors) GroupLayers() map[string]Layer {
layers := make(map[string]Layer) layers := make(map[string]Layer)
for _, t := range ts.items { for _, t := range s.items {
parts := strings.Split(t.Name, ".") parts := strings.Split(t.Name, ".")
if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 { if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
if len(parts) > index+2 { if len(parts) > index+2 {

View File

@ -5,6 +5,7 @@ import (
"cmp" "cmp"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@ -225,7 +226,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
Name: name, Name: name,
Kind: kind, Kind: kind,
Offset: offset, Offset: offset,
Shape: shape[:], Shape: shape,
} }
llm.tensors = append(llm.tensors, &tensor) llm.tensors = append(llm.tensors, &tensor)
@ -511,7 +512,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
arch := kv.String("general.architecture") arch := kv.String("general.architecture")
if arch == "" { if arch == "" {
return fmt.Errorf("architecture not set") return errors.New("architecture not set")
} }
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil { if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
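The Errorf-to-errors.New swap above is the general rule for constant messages: reach for fmt.Errorf only when there is something to format. A small sketch with an invented check, not taken from this file:

package main

import (
    "errors"
    "fmt"
)

func checkArch(arch string) error {
    if arch == "" {
        // No format verbs, so errors.New is enough.
        return errors.New("architecture not set")
    }
    // With a verb, fmt.Errorf earns its keep.
    return fmt.Errorf("unsupported architecture %q", arch)
}

func main() {
    fmt.Println(checkArch(""))
}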

View File

@ -136,8 +136,8 @@ func (t FileType) Value() uint32 {
return uint32(t) return uint32(t)
} }
func (ftype FileType) ToTensorType() TensorType { func (t FileType) ToTensorType() TensorType {
switch ftype { switch t {
case FileTypeF32: case FileTypeF32:
return TensorTypeF32 return TensorTypeF32
case FileTypeF16: case FileTypeF16:
@ -177,7 +177,7 @@ func (ftype FileType) ToTensorType() TensorType {
case fileTypeMXFP4: case fileTypeMXFP4:
return TensorTypeMXFP4 return TensorTypeMXFP4
default: default:
slog.Warn("unsupported file type", "type", ftype) slog.Warn("unsupported file type", "type", t)
return 0 // F32 return 0 // F32
} }
} }

View File

@ -11,7 +11,7 @@ type KeyValue struct {
} }
func (kv KeyValue) Valid() bool { func (kv KeyValue) Valid() bool {
return kv.Key != "" && kv.Value.value != nil return kv.Key != "" && kv.value != nil
} }
type Value struct { type Value struct {

View File

@ -200,9 +200,7 @@ func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
before := raw[:channelIndex] before := raw[:channelIndex]
after := raw[channelIndex+len("<|channel|>"):] after := raw[channelIndex+len("<|channel|>"):]
// the channel name is `after` all the way up to the first (if any) whitespace character // the channel name is `after` all the way up to the first (if any) whitespace character
idx := strings.IndexFunc(after, func(r rune) bool { idx := strings.IndexFunc(after, unicode.IsSpace)
return unicode.IsSpace(r)
})
if idx == -1 { if idx == -1 {
idx = len(after) idx = len(after)
} }
@ -319,11 +317,12 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
} }
case HarmonyEventContentEmitted: case HarmonyEventContentEmitted:
logutil.Trace("harmony event content", "content", event.Content, "state", h.state) logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
if h.state == harmonyMessageState_Normal { switch h.state {
case harmonyMessageState_Normal:
contentSb.WriteString(event.Content) contentSb.WriteString(event.Content)
} else if h.state == harmonyMessageState_Thinking { case harmonyMessageState_Thinking:
thinkingSb.WriteString(event.Content) thinkingSb.WriteString(event.Content)
} else if h.state == harmonyMessageState_ToolCalling { case harmonyMessageState_ToolCalling:
toolContentSb.WriteString(event.Content) toolContentSb.WriteString(event.Content)
} }
case HarmonyEventMessageEnd: case HarmonyEventMessageEnd:
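Two patterns land in this hunk: passing unicode.IsSpace directly to strings.IndexFunc instead of wrapping it in a closure, and folding an if/else-if chain over a single value into a switch. A hedged sketch of the latter, with made-up states rather than the real parser types:

package main

import "fmt"

type state int

const (
    stateNormal state = iota
    stateThinking
    stateToolCalling
)

func bucket(s state) string {
    // One switch over the value reads flatter than chained == comparisons.
    switch s {
    case stateNormal:
        return "content"
    case stateThinking:
        return "thinking"
    case stateToolCalling:
        return "tool call"
    default:
        return "unknown"
    }
}

func main() {
    fmt.Println(bucket(stateThinking))
}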

View File

@ -263,9 +263,9 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.use_mmap = C.bool(params.UseMmap) cparams.use_mmap = C.bool(params.UseMmap)
cparams.vocab_only = C.bool(params.VocabOnly) cparams.vocab_only = C.bool(params.VocabOnly)
var devices []C.ggml_backend_dev_t devices := make([]C.ggml_backend_dev_t, len(params.Devices))
for _, llamaID := range params.Devices { for i, llamaID := range params.Devices {
devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID))) devices[i] = C.ggml_backend_dev_get(C.size_t(llamaID))
} }
if len(devices) > 0 { if len(devices) > 0 {
devices = append(devices, C.ggml_backend_dev_t(C.NULL)) devices = append(devices, C.ggml_backend_dev_t(C.NULL))

View File

@ -250,7 +250,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if s.status != nil && s.status.LastErrMsg != "" { if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg msg = s.status.LastErrMsg
} }
err := fmt.Errorf("error starting runner: %v %s", err, msg) err := fmt.Errorf("error starting runner: %w %s", err, msg)
if llamaModel != nil { if llamaModel != nil {
llama.FreeModel(llamaModel) llama.FreeModel(llamaModel)
} }
@ -846,14 +846,7 @@ nextOperation:
func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID { func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
devices := []ml.DeviceID{} devices := []ml.DeviceID{}
for _, layer := range gpuLayers { for _, layer := range gpuLayers {
new := true if !slices.Contains(devices, layer.DeviceID) {
for _, ID := range devices {
if layer.DeviceID == ID {
new = false
break
}
}
if new {
devices = append(devices, layer.DeviceID) devices = append(devices, layer.DeviceID)
} }
} }
@ -989,13 +982,11 @@ nextLayer:
slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap)) slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available)) return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available))
} }
} else { } else if vramSize > systemInfo.TotalMemory {
if vramSize > systemInfo.TotalMemory { // disable partial offloading when model is greater than total system memory as this
// disable partial offloading when model is greater than total system memory as this // can lead to locking up the system
// can lead to locking up the system s.options.NumGPU = 0
s.options.NumGPU = 0 gpuLayers = ml.GPULayersList{}
gpuLayers = ml.GPULayersList{}
}
} }
if gpuLayers.Sum() == 0 { if gpuLayers.Sum() == 0 {
@ -1218,7 +1209,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
if err != nil { if err != nil {
return ServerStatusError, fmt.Errorf("error creating GET request: %v", err) return ServerStatusError, fmt.Errorf("error creating GET request: %w", err)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@ -1481,7 +1472,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// User provided a JSON schema // User provided a JSON schema
g := llama.SchemaToGrammar(req.Format) g := llama.SchemaToGrammar(req.Format)
if g == nil { if g == nil {
return fmt.Errorf("invalid JSON schema in format") return errors.New("invalid JSON schema in format")
} }
req.Grammar = string(g) req.Grammar = string(g)
} }
@ -1521,13 +1512,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
enc.SetEscapeHTML(false) enc.SetEscapeHTML(false)
if err := enc.Encode(req); err != nil { if err := enc.Encode(req); err != nil {
return fmt.Errorf("failed to marshal data: %v", err) return fmt.Errorf("failed to marshal data: %w", err)
} }
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port) endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer) serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil { if err != nil {
return fmt.Errorf("error creating POST request: %v", err) return fmt.Errorf("error creating POST request: %w", err)
} }
serverReq.Header.Set("Content-Type", "application/json") serverReq.Header.Set("Content-Type", "application/json")
@ -1576,7 +1567,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
var c CompletionResponse var c CompletionResponse
if err := json.Unmarshal(evt, &c); err != nil { if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err) return fmt.Errorf("error unmarshalling llm prediction response: %w", err)
} }
switch { switch {
case strings.TrimSpace(c.Content) == lastToken: case strings.TrimSpace(c.Content) == lastToken:
@ -1618,7 +1609,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("an error was encountered while running the model: %s", msg) return fmt.Errorf("an error was encountered while running the model: %s", msg)
} }
return fmt.Errorf("error reading llm response: %v", err) return fmt.Errorf("error reading llm response: %w", err)
} }
return nil return nil
@ -1693,7 +1684,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
defer s.llamaModelLock.Unlock() defer s.llamaModelLock.Unlock()
if s.llamaModel == nil { if s.llamaModel == nil {
return nil, fmt.Errorf("no tokenizer configured") return nil, errors.New("no tokenizer configured")
} }
return s.llamaModel.Tokenize(content, false, true) return s.llamaModel.Tokenize(content, false, true)
@ -1718,15 +1709,15 @@ func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
defer s.llamaModelLock.Unlock() defer s.llamaModelLock.Unlock()
if s.llamaModel == nil { if s.llamaModel == nil {
return "", fmt.Errorf("no tokenizer configured") return "", errors.New("no tokenizer configured")
} }
var resp string var sb strings.Builder
for _, token := range tokens { for _, token := range tokens {
resp += s.llamaModel.TokenToPiece(token) sb.WriteString(s.llamaModel.TokenToPiece(token))
} }
return resp, nil return sb.String(), nil
} }
func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) { func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
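The Detokenize change above swaps string concatenation in a loop for strings.Builder, which appends into one growing buffer instead of allocating a new string on every +=. A self-contained sketch with placeholder pieces:

package main

import (
    "fmt"
    "strings"
)

func main() {
    pieces := []string{"Hel", "lo", ",", " ", "world"}
    var sb strings.Builder
    for _, p := range pieces {
        sb.WriteString(p) // amortized appends, no per-iteration string copy
    }
    fmt.Println(sb.String()) // Hello, world
}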

View File

@ -209,7 +209,7 @@ func TestLLMServerFitGPU(t *testing.T) {
} }
gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0) gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
if err != tt.expectedErr { if !errors.Is(err, tt.expectedErr) {
t.Fatalf("fitGPU returned error: %v", err) t.Fatalf("fitGPU returned error: %v", err)
} }
if gpuLayers.Hash() != tt.expected.Hash() { if gpuLayers.Hash() != tt.expected.Hash() {
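The test now compares errors with errors.Is rather than ==, so a wrapped sentinel still matches. Roughly, with an invented sentinel and function:

package main

import (
    "errors"
    "fmt"
)

var errTooLarge = errors.New("model too large")

func createLayout() error {
    // Wrapping with %w keeps the sentinel reachable in the error chain.
    return fmt.Errorf("layout failed: %w", errTooLarge)
}

func main() {
    err := createLayout()
    fmt.Println(err == errTooLarge)          // false: the wrapper is a different value
    fmt.Println(errors.Is(err, errTooLarge)) // true: unwraps the chain
}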

View File

@ -84,7 +84,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
} }
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) _, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -98,7 +98,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) _, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -123,7 +123,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
} }
func (w *ChatWriter) Write(data []byte) (int, error) { func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status() code := w.Status()
if code != http.StatusOK { if code != http.StatusOK {
return w.writeError(data) return w.writeError(data)
} }
@ -150,7 +150,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
} }
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) _, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -164,7 +164,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d))) _, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -189,7 +189,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
} }
func (w *CompleteWriter) Write(data []byte) (int, error) { func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status() code := w.Status()
if code != http.StatusOK { if code != http.StatusOK {
return w.writeError(data) return w.writeError(data)
} }
@ -214,7 +214,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
} }
func (w *ListWriter) Write(data []byte) (int, error) { func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status() code := w.Status()
if code != http.StatusOK { if code != http.StatusOK {
return w.writeError(data) return w.writeError(data)
} }
@ -240,7 +240,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
} }
func (w *RetrieveWriter) Write(data []byte) (int, error) { func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status() code := w.Status()
if code != http.StatusOK { if code != http.StatusOK {
return w.writeError(data) return w.writeError(data)
} }
@ -265,7 +265,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
} }
func (w *EmbedWriter) Write(data []byte) (int, error) { func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status() code := w.Status()
if code != http.StatusOK { if code != http.StatusOK {
return w.writeError(data) return w.writeError(data)
} }
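Each writer above drops the Write([]byte(fmt.Sprintf(...))) round trip in favour of fmt.Fprintf, which formats straight into the io.Writer. A sketch against os.Stdout rather than the real ResponseWriter:

package main

import (
    "fmt"
    "os"
)

func main() {
    d := []byte(`{"id":"chatcmpl-123"}`)
    // One call, no intermediate string or byte slice.
    fmt.Fprintf(os.Stdout, "data: %s\n\n", d)
}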

View File

@ -68,7 +68,7 @@ func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) {
switch tc.expectType { switch tc.expectType {
case "array": case "array":
if _, ok := result.Data[0].Embedding.([]interface{}); !ok { if _, ok := result.Data[0].Embedding.([]any); !ok {
t.Errorf("expected array, got %T", result.Data[0].Embedding) t.Errorf("expected array, got %T", result.Data[0].Embedding)
} }
case "string": case "string":
@ -210,10 +210,8 @@ func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) {
if !strings.Contains(errResp.Error.Message, "encoding_format") { if !strings.Contains(errResp.Error.Message, "encoding_format") {
t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message) t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message)
} }
} else { } else if resp.Code != http.StatusOK {
if resp.Code != http.StatusOK { t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
}
} }
}) })
} }

View File

@ -845,19 +845,17 @@ func TestListMiddleware(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
var expected, actual map[string]any var want, got map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected) if err := json.Unmarshal([]byte(tc.resp), &want); err != nil {
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err) t.Fatalf("failed to unmarshal expected response: %v", err)
} }
err = json.Unmarshal(resp.Body.Bytes(), &actual) if err := json.Unmarshal(resp.Body.Bytes(), &got); err != nil {
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err) t.Fatalf("failed to unmarshal actual response: %v", err)
} }
if !reflect.DeepEqual(expected, actual) { if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual) t.Errorf("response does not match (-want +got):\n%s", diff)
} }
} }
} }

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"math" "math"
"slices" "slices"
@ -92,7 +93,7 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
return backend(modelPath, params) return backend(modelPath, params)
} }
return nil, fmt.Errorf("unsupported backend") return nil, errors.New("unsupported backend")
} }
type Context interface { type Context interface {

View File

@ -178,14 +178,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory.CPU.Cache = make([]uint64, blocks+1) requiredMemory.CPU.Cache = make([]uint64, blocks+1)
// create list of buffer types for each gpu // create list of buffer types for each gpu
var gpuDeviceBufferTypes []deviceBufferType gpuDeviceBufferTypes := make([]deviceBufferType, len(gpus))
requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus)) requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
for i, d := range gpus { for i, d := range gpus {
bt := C.ggml_backend_dev_buffer_type(d) bt := C.ggml_backend_dev_buffer_type(d)
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{ gpuDeviceBufferTypes[i] = deviceBufferType{
d: d, d: d,
bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...), bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
}) }
btDeviceMemory[bt] = &requiredMemory.GPUs[i] btDeviceMemory[bt] = &requiredMemory.GPUs[i]
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
@ -354,8 +354,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t) deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t)
// create backends and buffer types used for the compute graph scheduler // create backends and buffer types used for the compute graph scheduler
var schedBackends []C.ggml_backend_t schedBackends := make([]C.ggml_backend_t, 0, len(cpus)+len(accels)+len(gpus))
var schedBufts []C.ggml_backend_buffer_type_t schedBufts := make([]C.ggml_backend_buffer_type_t, 0, len(cpus)+len(accels)+len(gpus))
for _, d := range append(gpus, append(accels, cpus...)...) { for _, d := range append(gpus, append(accels, cpus...)...) {
b := backends[d] b := backends[d]
bt := C.ggml_backend_get_default_buffer_type(b) bt := C.ggml_backend_get_default_buffer_type(b)

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"hash/maphash" "hash/maphash"
"io" "io"
@ -218,7 +219,7 @@ type BackendMemory struct {
} }
func (m BackendMemory) LogValue() slog.Value { func (m BackendMemory) LogValue() slog.Value {
var attrs []slog.Attr attrs := make([]slog.Attr, 0, 2+len(m.GPUs))
if m.InputWeights != 0 { if m.InputWeights != 0 {
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights)) attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
} }
@ -414,14 +415,7 @@ func LibraryPaths(l []DeviceInfo) []string {
gpuLibs := []string{LibOllamaPath} gpuLibs := []string{LibOllamaPath}
for _, gpu := range l { for _, gpu := range l {
for _, dir := range gpu.LibraryPath { for _, dir := range gpu.LibraryPath {
needed := true if !slices.Contains(gpuLibs, dir) {
for _, existing := range gpuLibs {
if dir == existing {
needed = false
break
}
}
if needed {
gpuLibs = append(gpuLibs, dir) gpuLibs = append(gpuLibs, dir)
} }
} }
@ -437,15 +431,15 @@ const (
DuplicateDevice // The same physical device but different library/backend (overlapping device) DuplicateDevice // The same physical device but different library/backend (overlapping device)
) )
func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison { func (d DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
if a.PCIID != b.PCIID { if d.PCIID != b.PCIID {
return UniqueDevice return UniqueDevice
} }
// If PCIID is empty, we have to use ID + library for uniqueness // If PCIID is empty, we have to use ID + library for uniqueness
if a.PCIID == "" && a.DeviceID != b.DeviceID { if d.PCIID == "" && d.DeviceID != b.DeviceID {
return UniqueDevice return UniqueDevice
} }
if a.Library == b.Library { if d.Library == b.Library {
return SameBackendDevice return SameBackendDevice
} }
return DuplicateDevice return DuplicateDevice
@ -453,8 +447,8 @@ func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
// For a SameBackendDevice, return true if b is better than a // For a SameBackendDevice, return true if b is better than a
// e.g. newer GPU library version // e.g. newer GPU library version
func (a DeviceInfo) IsBetter(b DeviceInfo) bool { func (d DeviceInfo) IsBetter(b DeviceInfo) bool {
aLib := a.LibraryPath[len(a.LibraryPath)-1] aLib := d.LibraryPath[len(d.LibraryPath)-1]
bLib := b.LibraryPath[len(b.LibraryPath)-1] bLib := b.LibraryPath[len(b.LibraryPath)-1]
if aLib == bLib { if aLib == bLib {
return false return false
@ -481,7 +475,7 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
for _, gpu := range l { for _, gpu := range l {
supportsFA := gpu.Library == "cpu" || supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" || gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && (gpu.ComputeMajor != 7 || gpu.ComputeMinor != 2)) ||
gpu.Library == "ROCm" || gpu.Library == "ROCm" ||
gpu.Library == "Vulkan" gpu.Library == "Vulkan"
@ -549,12 +543,12 @@ func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) {
} }
v, existing := env[envVar] v, existing := env[envVar]
if existing { if existing {
v = v + "," v += ","
} }
if d.FilterID != "" { if d.FilterID != "" {
v = v + d.FilterID v += d.FilterID
} else { } else {
v = v + d.ID v += d.ID
} }
env[envVar] = v env[envVar] = v
} }
@ -594,7 +588,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, fmt.Errorf("failed to finish discovery before timeout") return nil, errors.New("failed to finish discovery before timeout")
case <-tick: case <-tick:
r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil) r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
if err != nil { if err != nil {
@ -606,7 +600,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
if err != nil { if err != nil {
// slog.Warn("failed to send request", "error", err) // slog.Warn("failed to send request", "error", err)
if runner.HasExited() { if runner.HasExited() {
return nil, fmt.Errorf("runner crashed") return nil, errors.New("runner crashed")
} }
continue continue
} }
@ -614,7 +608,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
if resp.StatusCode == http.StatusNotFound { if resp.StatusCode == http.StatusNotFound {
// old runner, fall back to bootstrapping model // old runner, fall back to bootstrapping model
return nil, fmt.Errorf("llamarunner free vram reporting not supported") return nil, errors.New("llamarunner free vram reporting not supported")
} }
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
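LibraryPaths above replaces a hand-rolled membership loop with slices.Contains, the same shape used in uniqueDeviceIDs earlier. A minimal sketch with invented directory names:

package main

import (
    "fmt"
    "slices"
)

func main() {
    libs := []string{"/usr/lib/ollama"}
    for _, dir := range []string{"/usr/lib/ollama", "/opt/rocm/lib", "/opt/rocm/lib"} {
        if !slices.Contains(libs, dir) { // replaces the inner "needed"/"new" loop
            libs = append(libs, dir)
        }
    }
    fmt.Println(libs)
}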

View File

@ -143,9 +143,9 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
case r == 0x00ad: case r == 0x00ad:
r = 0x0143 r = 0x0143
case r <= 0x0020: case r <= 0x0020:
r = r + 0x0100 r += 0x0100
case r >= 0x007f && r <= 0x00a0: case r >= 0x007f && r <= 0x00a0:
r = r + 0x00a2 r += 0x00a2
} }
sb.WriteRune(r) sb.WriteRune(r)
@ -264,9 +264,9 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
case r == 0x0143: case r == 0x0143:
r = 0x00ad r = 0x00ad
case r > 0x0100 && r <= 0x0120: case r > 0x0100 && r <= 0x0120:
r = r - 0x0100 r -= 0x0100
case r > 0x0120 && r <= 0x0142: case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2 r -= 0x00a2
} }
// NOTE: not using WriteRune here because it writes the UTF-8 // NOTE: not using WriteRune here because it writes the UTF-8

View File

@ -146,7 +146,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
func modelForArch(c fs.Config) (Model, error) { func modelForArch(c fs.Config) (Model, error) {
arch := c.Architecture() arch := c.Architecture()
if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone { if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed" arch += "_embed"
} }
f, ok := models[arch] f, ok := models[arch]
@ -175,9 +175,10 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
tagsCopy = append(tagsCopy, parseTag(tag)) tagsCopy = append(tagsCopy, parseTag(tag))
} }
if tt == reflect.TypeOf((*Base)(nil)).Elem() { switch {
case tt == reflect.TypeFor[Base]():
vv.Set(reflect.ValueOf(base)) vv.Set(reflect.ValueOf(base))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { case tt == reflect.TypeFor[ml.Tensor]():
var fn func([]Tag, string, string) [][]string var fn func([]Tag, string, string) [][]string
fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
if len(tags) > 0 { if len(tags) > 0 {
@ -217,9 +218,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
break break
} }
} }
} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { case tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface:
setPointer(base, vv, tagsCopy) setPointer(base, vv, tagsCopy)
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { case tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array:
for i := range vv.Len() { for i := range vv.Len() {
vvv := vv.Index(i) vvv := vv.Index(i)
if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
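populateFields now uses reflect.TypeFor (Go 1.22+) instead of reflect.TypeOf((*T)(nil)).Elem(), alongside the switch rewrite. A tiny sketch with a stand-in interface, not the real Base or ml.Tensor types:

package main

import (
    "fmt"
    "reflect"
)

type Tensor interface {
    Name() string
}

func main() {
    // Both expressions yield the interface type itself, not a pointer to it.
    old := reflect.TypeOf((*Tensor)(nil)).Elem()
    cur := reflect.TypeFor[Tensor]()
    fmt.Println(old == cur, cur) // true main.Tensor
}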

View File

@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.attnKeyLen, m.ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
} }
type MLP struct { type MLP struct {
@ -178,10 +178,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount { if len(m.Layers) == gemma27BLayerCount {
m.Options.largeModelScaling = true m.largeModelScaling = true
} }
for i, layer := range m.Layers { for i, layer := range m.Layers {
@ -202,9 +202,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenState = m.Output.Forward(ctx, hiddenState) hiddenState = m.Output.Forward(ctx, hiddenState)
// final logit softcap // final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap)) hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx) hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil return hiddenState.Scale(ctx, float64(m.finalLogitSoftcap)), nil
} }
func init() { func init() {

View File

@ -96,15 +96,15 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err return nil, err
} }
f32s, err := m.ImageProcessor.ProcessImage(image) f32s, err := m.ProcessImage(image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pixelValues := ctx.Input().FromFloats(f32s, pixelValues := ctx.Input().FromFloats(f32s,
m.ImageProcessor.imageSize, m.imageSize,
m.ImageProcessor.imageSize, m.imageSize,
m.ImageProcessor.numChannels, m.numChannels,
) )
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs := m.VisionModel.Forward(ctx, pixelValues)

View File

@ -111,12 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
} }
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextConfig.ropeLocalBase ropeBase := m.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 { if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextConfig.ropeGlobalBase ropeBase = m.ropeGlobalBase
} }
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil return fast.RoPE(ctx, key, shift, m.attnKeyLen, ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
} }
type TextMLP struct { type TextMLP struct {
@ -166,7 +166,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
// set image embeddings // set image embeddings
var except []int var except []int

View File

@ -53,7 +53,7 @@ func New(c fs.Config) (model.Model, error) {
MultiModalProjector: newMultiModalProjector(c), MultiModalProjector: newMultiModalProjector(c),
} }
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) m.Cache = kvcache.NewCausalCache(m.Shift)
return m, nil return m, nil
} }
@ -109,12 +109,12 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err return nil, err
} }
f32s, size, err := m.ImageProcessor.ProcessImage(image) f32s, size, err := m.ProcessImage(image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.ImageProcessor.numChannels) pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.numChannels)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)

View File

@ -133,7 +133,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize) hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps) hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.eps)
// Prepare position IDs for 2D rope // Prepare position IDs for 2D rope
positions := make([]int32, numPatches) positions := make([]int32, numPatches)

View File

@ -54,7 +54,7 @@ func New(c fs.Config) (model.Model, error) {
encoderCache := kvcache.NewEncoderCache() encoderCache := kvcache.NewEncoderCache()
encoderCache.SetConfig(ml.CacheConfig{}) encoderCache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift)) m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.Shift))
return &m, nil return &m, nil
} }
@ -69,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err return nil, err
} }
f32s, ratio, err := m.ImageProcessor.ProcessImage(image) f32s, ratio, err := m.ProcessImage(image)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -223,8 +223,8 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, cros
} }
func newTextModel(c fs.Config) *TextModel { func newTextModel(c fs.Config) *TextModel {
var decoderLayers []TextDecoderLayer decoderLayers := make([]TextDecoderLayer, c.Uint("block_count"))
for i := range c.Uint("block_count") { for i := range decoderLayers {
var textDecoderLayer TextDecoderLayer var textDecoderLayer TextDecoderLayer
if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) { if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
textDecoderLayer = &TextCrossAttentionDecoderLayer{} textDecoderLayer = &TextCrossAttentionDecoderLayer{}
@ -232,7 +232,7 @@ func newTextModel(c fs.Config) *TextModel {
textDecoderLayer = &TextSelfAttentionDecoderLayer{} textDecoderLayer = &TextSelfAttentionDecoderLayer{}
} }
decoderLayers = append(decoderLayers, textDecoderLayer) decoderLayers[i] = textDecoderLayer
} }
return &TextModel{ return &TextModel{

View File

@ -2,6 +2,7 @@ package qwen2
import ( import (
"cmp" "cmp"
"errors"
"fmt" "fmt"
"math" "math"
"strings" "strings"
@ -130,7 +131,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
// This model currently only supports the gpt2 tokenizer // This model currently only supports the gpt2 tokenizer
if c.String("tokenizer.ggml.model") == "llama" { if c.String("tokenizer.ggml.model") == "llama" {
return nil, fmt.Errorf("unsupported tokenizer: llama") return nil, errors.New("unsupported tokenizer: llama")
} }
// detect library/qwen model(s) which are incompatible // detect library/qwen model(s) which are incompatible
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") { if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {

View File

@ -48,7 +48,7 @@ func New(c fs.Config) (model.Model, error) {
ImageProcessor: newImageProcessor(c), ImageProcessor: newImageProcessor(c),
} }
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift) m.Cache = kvcache.NewCausalCache(m.Shift)
return m, nil return m, nil
} }
@ -59,14 +59,13 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
return nil, nil, err return nil, nil, err
} }
f32s, grid, err := m.ImageProcessor.ProcessImage(image) f32s, grid, err := m.ProcessImage(image)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
// Calculate tensor dimensions // Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize * patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches) pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)

View File

@ -228,7 +228,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1)) cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1)) sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads) mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.numHeads)
// Apply encoder layers // Apply encoder layers
for i, layer := range m.Layers { for i, layer := range m.Layers {
if slices.Contains(m.fullAttnBlocks, int32(i)) { if slices.Contains(m.fullAttnBlocks, int32(i)) {

View File

@ -107,7 +107,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid) patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err) return nil, nil, fmt.Errorf("failed to create patches: %w", err)
} }
// Return patches and grid dimensions // Return patches and grid dimensions

View File

@ -203,7 +203,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
} }
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
} }
var _ model.Model = (*Model)(nil) var _ model.Model = (*Model)(nil)

View File

@ -111,7 +111,7 @@ func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tenso
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid) patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err) return nil, nil, fmt.Errorf("failed to create patches: %w", err)
} }
patchDim := p.numChannels * p.temporalPatchSize * patchDim := p.numChannels * p.temporalPatchSize *

View File

@ -98,7 +98,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
if multiStepTool && message.Role == "user" { if multiStepTool && message.Role == "user" {
// Check if content starts with <tool_response> and ends with </tool_response> // Check if content starts with <tool_response> and ends with </tool_response>
content := r.renderContent(message) content := r.renderContent(message)
if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) { if !strings.HasPrefix(content, "<tool_response>") || !strings.HasSuffix(content, "</tool_response>") {
multiStepTool = false multiStepTool = false
lastQueryIndex = i lastQueryIndex = i
} }
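The condition above is De Morgan's rewrite: !(A && B) becomes !A || !B, which drops the outer negation. A sketch with sample strings only:

package main

import (
    "fmt"
    "strings"
)

func main() {
    content := "<tool_response>{}</tool_response>"
    // Equivalent to !(HasPrefix && HasSuffix), written without the wrapping !(...).
    if !strings.HasPrefix(content, "<tool_response>") || !strings.HasSuffix(content, "</tool_response>") {
        fmt.Println("not a tool response")
    } else {
        fmt.Println("tool response")
    }
}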

View File

@ -205,12 +205,12 @@ func (q queue) Less(i, j int) bool {
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] } func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
func (q *queue) Push(x interface{}) { func (q *queue) Push(x any) {
item := x.(*candidate) item := x.(*candidate)
*q = append(*q, item) *q = append(*q, item)
} }
func (q *queue) Pop() interface{} { func (q *queue) Pop() any {
old := *q old := *q
n := len(old) n := len(old)
item := old[n-1] item := old[n-1]
@ -231,7 +231,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") { if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
byteVal, err := strconv.ParseUint(data[1:5], 0, 8) byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to parse hex byte: %v", err) return "", fmt.Errorf("failed to parse hex byte: %w", err)
} }
if err := sb.WriteByte(byte(byteVal)); err != nil { if err := sb.WriteByte(byte(byteVal)); err != nil {

View File

@ -232,9 +232,9 @@ func NewError(code int, message string) ErrorResponse {
// ToUsage converts an api.ChatResponse to Usage // ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage { func ToUsage(r api.ChatResponse) Usage {
return Usage{ return Usage{
PromptTokens: r.Metrics.PromptEvalCount, PromptTokens: r.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount, CompletionTokens: r.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount, TotalTokens: r.PromptEvalCount + r.EvalCount,
} }
} }
@ -326,9 +326,9 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
// ToUsageGenerate converts an api.GenerateResponse to Usage // ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage { func ToUsageGenerate(r api.GenerateResponse) Usage {
return Usage{ return Usage{
PromptTokens: r.Metrics.PromptEvalCount, PromptTokens: r.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount, CompletionTokens: r.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount, TotalTokens: r.PromptEvalCount + r.EvalCount,
} }
} }
@ -377,20 +377,19 @@ func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
// ToListCompletion converts an api.ListResponse to ListCompletion // ToListCompletion converts an api.ListResponse to ListCompletion
func ToListCompletion(r api.ListResponse) ListCompletion { func ToListCompletion(r api.ListResponse) ListCompletion {
var data []Model c := ListCompletion{Object: "list"}
for _, m := range r.Models { if len(r.Models) > 0 {
data = append(data, Model{ c.Data = make([]Model, len(r.Models))
Id: m.Name, for i, m := range r.Models {
Object: "model", c.Data[i] = Model{
Created: m.ModifiedAt.Unix(), Id: m.Name,
OwnedBy: model.ParseName(m.Name).Namespace, Object: "model",
}) Created: m.ModifiedAt.Unix(),
} OwnedBy: model.ParseName(m.Name).Namespace,
}
return ListCompletion{ }
Object: "list",
Data: data,
} }
return c
} }
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList // ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
@ -487,19 +486,14 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
} }
types := []string{"jpeg", "jpg", "png", "webp"} url, valid := strings.CutPrefix(url, "data:;base64,")
valid := false if !valid {
// support blank mime type to match api/chat taking just unadorned base64 for _, t := range []string{"jpeg", "jpg", "png", "webp"} {
if strings.HasPrefix(url, "data:;base64,") { prefix := "data:image/" + t + ";base64,"
url = strings.TrimPrefix(url, "data:;base64,") url, valid = strings.CutPrefix(url, prefix)
valid = true if valid {
} break
for _, t := range types { }
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
} }
} }
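The image URL handling collapses the HasPrefix + TrimPrefix pair into strings.CutPrefix, which trims and reports the match in one call. A sketch over a sample data URL, not the real request flow:

package main

import (
    "fmt"
    "strings"
)

func main() {
    url := "data:image/png;base64,iVBORw0KGgo="
    for _, t := range []string{"jpeg", "jpg", "png", "webp"} {
        if rest, ok := strings.CutPrefix(url, "data:image/"+t+";base64,"); ok {
            fmt.Println(t, rest) // png iVBORw0KGgo=
            break
        }
    }
}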

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"maps"
"net/http" "net/http"
"os" "os"
"os/user" "os/user"
@ -78,9 +79,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
if req.Files == nil { if req.Files == nil {
req.Files = digestMap req.Files = digestMap
} else { } else {
for k, v := range digestMap { maps.Copy(req.Files, digestMap)
req.Files[k] = v
}
} }
case "adapter": case "adapter":
path, err := expandPath(c.Args, relativeDir) path, err := expandPath(c.Args, relativeDir)
@ -371,7 +370,7 @@ func (e *ParserError) Error() string {
func ParseFile(r io.Reader) (*Modelfile, error) { func ParseFile(r io.Reader) (*Modelfile, error) {
var cmd Command var cmd Command
var curr state var curr state
var currLine int = 1 currLine := 1
var b bytes.Buffer var b bytes.Buffer
var role string var role string
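The Modelfile handling above uses maps.Copy to merge the digest map into req.Files instead of a manual range loop. A minimal sketch with invented digests:

package main

import (
    "fmt"
    "maps"
)

func main() {
    files := map[string]string{"model.gguf": "sha256-aaa"}
    digests := map[string]string{"adapter.gguf": "sha256-bbb"}
    maps.Copy(files, digests) // copies every key/value from digests into files
    fmt.Println(files)
}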

View File

@ -326,17 +326,11 @@ MESSAGE system`,
return return
} }
switch tt.err.(type) {
case *ParserError:
var pErr *ParserError
if errors.As(err, &pErr) {
// got the correct type of error
return
}
}
if errors.Is(err, tt.err) { if errors.Is(err, tt.err) {
return return
} else if pErr := (*ParserError)(nil); errors.As(err, &pErr) {
// got the correct type of error
return
} }
t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err) t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err)
@ -1089,7 +1083,7 @@ func TestFilesForModel(t *testing.T) {
if err == nil { if err == nil {
t.Error("Expected error, but got none") t.Error("Expected error, but got none")
} }
if tt.expectErrType != nil && err != tt.expectErrType { if tt.expectErrType != nil && !errors.Is(err, tt.expectErrType) {
t.Errorf("Expected error type %v, got %v", tt.expectErrType, err) t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
} }
return return

View File

@ -3,6 +3,7 @@ package readline
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"github.com/emirpasic/gods/v2/lists/arraylist" "github.com/emirpasic/gods/v2/lists/arraylist"
"github.com/mattn/go-runewidth" "github.com/mattn/go-runewidth"
@ -297,7 +298,7 @@ func (b *Buffer) drawRemaining() {
remaining := (remainingText[len(currLine):]) remaining := (remainingText[len(currLine):])
var totalLines int var totalLines int
var displayLength int var displayLength int
var lineLength int = currLineSpace lineLength := currLineSpace
for _, c := range remaining { for _, c := range remaining {
if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth { if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
@ -515,13 +516,13 @@ func (b *Buffer) StringN(n int) string {
} }
func (b *Buffer) StringNM(n, m int) string { func (b *Buffer) StringNM(n, m int) string {
var s string var sb strings.Builder
if m == 0 { if m == 0 {
m = b.Buf.Size() m = b.Buf.Size()
} }
for cnt := n; cnt < m; cnt++ { for cnt := n; cnt < m; cnt++ {
c, _ := b.Buf.Get(cnt) c, _ := b.Buf.Get(cnt)
s += string(c) sb.WriteRune(c)
} }
return s return sb.String()
} }

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"strings"
) )
type Prompt struct { type Prompt struct {
@ -124,18 +125,19 @@ func (i *Instance) Readline() (string, error) {
case KeyRight: case KeyRight:
buf.MoveRight() buf.MoveRight()
case CharBracketedPaste: case CharBracketedPaste:
var code string var code strings.Builder
for range 3 { for range 3 {
r, err = i.Terminal.Read() r, err = i.Terminal.Read()
if err != nil { if err != nil {
return "", io.EOF return "", io.EOF
} }
code += string(r) code.WriteRune(r)
} }
if code == CharBracketedPasteStart { switch code.String() {
case CharBracketedPasteStart:
i.Pasting = true i.Pasting = true
} else if code == CharBracketedPasteEnd { case CharBracketedPasteEnd:
i.Pasting = false i.Pasting = false
} }
case KeyDel: case KeyDel:

View File

@ -459,10 +459,7 @@ func TestLogprobsWithStopSequences(t *testing.T) {
origLogprobsLen := len(logprobs) origLogprobsLen := len(logprobs)
numTokensRemoved := origLen - newLen numTokensRemoved := origLen - newLen
newLogprobsLen := origLogprobsLen - numTokensRemoved newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
if newLogprobsLen < 0 {
newLogprobsLen = 0
}
logprobs = logprobs[:newLogprobsLen] logprobs = logprobs[:newLogprobsLen]
// Verify responses were truncated correctly // Verify responses were truncated correctly
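The truncation above leans on the built-in max (Go 1.21+) to clamp a possibly negative length at zero instead of an if statement. With toy numbers:

package main

import "fmt"

func main() {
    origLogprobs, removed := 3, 5
    n := max(origLogprobs-removed, 0) // clamp at zero in one expression
    fmt.Println(n)                    // 0
}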

View File

@ -39,21 +39,15 @@ func TruncateStop(pieces []string, stop string) ([]string, bool) {
joined = joined[:index] joined = joined[:index]
// Split truncated string back into pieces of original lengths result := make([]string, 0, len(pieces))
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
}
var result []string
tokenTruncated := false tokenTruncated := false
start := 0 start := 0
for _, length := range lengths { for _, piece := range pieces {
if start >= len(joined) { if start >= len(joined) {
break break
} }
end := start + length end := start + len(piece)
if end > len(joined) { if end > len(joined) {
end = len(joined) end = len(joined)
tokenTruncated = true tokenTruncated = true

View File

@ -61,7 +61,7 @@ func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []by
return nil, nil return nil, nil
} }
if len(data) <= 0 { if len(data) == 0 {
return nil, errors.New("received zero length image") return nil, errors.New("received zero length image")
} }

View File

@ -1,6 +1,7 @@
package llamarunner package llamarunner
import ( import (
"errors"
"reflect" "reflect"
"testing" "testing"
@ -18,7 +19,7 @@ func TestImageCache(t *testing.T) {
// Empty cache // Empty cache
result, err := cache.findImage(0x5adb61d31933a946) result, err := cache.findImage(0x5adb61d31933a946)
if err != errImageNotFound { if !errors.Is(err, errImageNotFound) {
t.Errorf("found result in empty cache: result %v, err %v", result, err) t.Errorf("found result in empty cache: result %v, err %v", result, err)
} }

View File

@ -577,10 +577,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
if seq.logprobs { if seq.logprobs {
origLogprobsLen := len(seq.pendingLogprobs) origLogprobsLen := len(seq.pendingLogprobs)
numTokensRemoved := origLen - newLen numTokensRemoved := origLen - newLen
newLogprobsLen := origLogprobsLen - numTokensRemoved newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
if newLogprobsLen < 0 {
newLogprobsLen = 0
}
seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen] seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
} }
@ -998,7 +995,6 @@ func Execute(args []string) error {
log.Println("Server listening on", addr) log.Println("Server listening on", addr)
if err := httpServer.Serve(listener); err != nil { if err := httpServer.Serve(listener); err != nil {
log.Fatal("server error:", err)
return err return err
} }

View File

@ -2,7 +2,6 @@ package ollamarunner
import ( import (
"errors" "errors"
"fmt"
"slices" "slices"
"testing" "testing"
"time" "time"
@ -511,7 +510,7 @@ type mockCache struct {
// Implement only the methods needed for the test // Implement only the methods needed for the test
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error { func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
if m.shouldFail { if m.shouldFail {
return fmt.Errorf("mock cache removal error") return errors.New("mock cache removal error")
} }
return nil return nil
} }

View File

@ -801,10 +801,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq.logprobs { if seq.logprobs {
origLogprobsLen := len(seq.pendingLogprobs) origLogprobsLen := len(seq.pendingLogprobs)
numTokensRemoved := origLen - newLen numTokensRemoved := origLen - newLen
newLogprobsLen := origLogprobsLen - numTokensRemoved newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
if newLogprobsLen < 0 {
newLogprobsLen = 0
}
seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen] seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
} }
@ -1242,7 +1239,7 @@ func (s *Server) loadModel() {
s.progress = progress s.progress = progress
}) })
if err != nil { if err != nil {
panic(fmt.Errorf("failed to load model: %v", err)) panic(fmt.Errorf("failed to load model: %w", err))
} }
s.status = llm.ServerStatusReady s.status = llm.ServerStatusReady
@ -1432,7 +1429,6 @@ func Execute(args []string) error {
log.Println("Server listening on", addr) log.Println("Server listening on", addr)
if err := httpServer.Serve(listener); err != nil { if err := httpServer.Serve(listener); err != nil {
log.Fatal("server error:", err)
return err return err
} }

View File

@ -30,7 +30,7 @@ func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability // Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7) temp = max(temp, 1e-7)
for i := range ts { for i := range ts {
ts[i].value = ts[i].value / temp ts[i].value /= temp
} }
} }

View File

@ -33,7 +33,7 @@ func (r registryChallenge) URL() (*url.URL, error) {
values := redirectURL.Query() values := redirectURL.Query()
values.Add("service", r.Service) values.Add("service", r.Service)
for _, s := range strings.Split(r.Scope, " ") { for s := range strings.SplitSeq(r.Scope, " ") {
values.Add("scope", s) values.Add("scope", s)
} }
@ -57,7 +57,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
} }
sha256sum := sha256.Sum256(nil) sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))) data := fmt.Appendf(nil, "%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))
headers := make(http.Header) headers := make(http.Header)
signature, err := auth.Sign(ctx, data) signature, err := auth.Sign(ctx, data)
@ -75,7 +75,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
body, err := io.ReadAll(response.Body) body, err := io.ReadAll(response.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("%d: %v", response.StatusCode, err) return "", fmt.Errorf("%d: %w", response.StatusCode, err)
} }
if response.StatusCode >= http.StatusBadRequest { if response.StatusCode >= http.StatusBadRequest {
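Two rewrites land in this file: strings.SplitSeq feeds the range loop without allocating an intermediate slice, and fmt.Appendf formats directly into a byte slice. A sketch of the latter with placeholder values, not the real signing data:

package main

import (
    "fmt"
    "net/http"
)

func main() {
    // Equivalent to []byte(fmt.Sprintf(...)) without the intermediate string.
    data := fmt.Appendf(nil, "%s,%s,%s", http.MethodGet, "https://registry.example/token", "c2lnbmF0dXJl")
    fmt.Println(string(data))
}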

View File

@ -386,7 +386,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
} }
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) { if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root // Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp) return nil, fmt.Errorf("%w: %w: %s", errFilePath, err, fp)
} }
blobPath, err := GetBlobsPath(digest) blobPath, err := GetBlobsPath(digest)
@ -456,15 +456,15 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
return l.KV(), nil return l.KV(), nil
} }
} }
return ggml.KV{}, fmt.Errorf("no base model was found") return ggml.KV{}, errors.New("no base model was found")
} }
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) { func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []Layer layers := make([]Layer, len(baseLayers))
for _, layer := range baseLayers { for i, layer := range baseLayers {
if layer.GGML != nil { if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization)) quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" { if quantType != "" && layer.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
want, err := ggml.ParseFileType(quantType) want, err := ggml.ParseFileType(quantType)
if err != nil { if err != nil {
return err return err
@ -480,13 +480,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
} }
} }
} }
config.ModelFormat = cmp.Or(config.ModelFormat, layer.GGML.Name()) config.ModelFormat = cmp.Or(config.ModelFormat, layer.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture()) config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount())) config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String()) config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture()) config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
} }
layers = append(layers, layer.Layer) layers[i] = layer.Layer
} }
if r.Template != "" { if r.Template != "" {
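
The prealloc rewrite above sizes the slice up front and assigns by index, avoiding the repeated growth that append causes when the final length is already known. A minimal sketch with hypothetical layer values:

    package main

    import "fmt"

    func main() {
        baseLayers := []string{"model", "template", "license"}

        // Before: nil slice grown by append.
        var grown []string
        for _, l := range baseLayers {
            grown = append(grown, l)
        }

        // After: allocate once with the known length, then index-assign.
        layers := make([]string, len(baseLayers))
        for i, l := range baseLayers {
            layers[i] = l
        }

        fmt.Println(grown, layers)
    }
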
@ -678,10 +678,10 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
func setTemplate(layers []Layer, t string) ([]Layer, error) { func setTemplate(layers []Layer, t string) ([]Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template") layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil { if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err) return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
} }
if _, err := template.Parse(t); err != nil { if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err) return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
} }
blob := strings.NewReader(t) blob := strings.NewReader(t)
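
Since Go 1.20, fmt.Errorf accepts more than one %w verb, so both the sentinel and the underlying parse error stay matchable with errors.Is; with %s only the first wrapped error survived in the chain. A hedged sketch with stand-in errors:

    package main

    import (
        "errors"
        "fmt"
    )

    var errBadTemplate = errors.New("bad template")

    func main() {
        parseErr := errors.New("unexpected EOF")

        // Go 1.20+: both errors join the unwrap chain.
        err := fmt.Errorf("%w: %w", errBadTemplate, parseErr)

        fmt.Println(errors.Is(err, errBadTemplate)) // true
        fmt.Println(errors.Is(err, parseErr))       // true
    }
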

View File

@ -640,7 +640,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
manifest, err = pullModelManifest(ctx, mp, regOpts) manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil { if err != nil {
return fmt.Errorf("pull model manifest: %s", err) return fmt.Errorf("pull model manifest: %w", err)
} }
var layers []Layer var layers []Layer
@ -786,7 +786,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
defer resp.Body.Close() defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err) return nil, fmt.Errorf("%d: %w", resp.StatusCode, err)
} }
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody) return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default: default:

View File

@ -438,7 +438,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
// last write. check hash. // last write. check hash.
sum := w.h.Sum(nil) sum := w.h.Sum(nil)
if !bytes.Equal(sum, w.d.sum[:]) { if !bytes.Equal(sum, w.d.sum[:]) {
return 0, w.seterr(fmt.Errorf("file content changed underfoot")) return 0, w.seterr(errors.New("file content changed underfoot"))
} }
if w.testHookBeforeFinalWrite != nil { if w.testHookBeforeFinalWrite != nil {
w.testHookBeforeFinalWrite(w.f) w.testHookBeforeFinalWrite(w.f)

View File

@ -84,8 +84,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
// TODO(bmizerany): Print platform-specific instructions or // TODO(bmizerany): Print platform-specific instructions or
// link to docs on that topic. // link to docs on that topic.
lines := strings.Split(volumeHint, "\n") for line := range strings.SplitSeq(volumeHint, "\n") {
for _, line := range lines {
t.Skip(line) t.Skip(line)
} }
} }

View File

@ -60,7 +60,7 @@ func (d Digest) String() string {
} }
func (d Digest) Short() string { func (d Digest) Short() string {
return fmt.Sprintf("%x", d.sum[:4]) return hex.EncodeToString(d.sum[:4])
} }
func (d Digest) Sum() [32]byte { func (d Digest) Sum() [32]byte {
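
For byte slices, hex.EncodeToString produces the same lowercase hex as fmt.Sprintf("%x", ...) without the formatting machinery, which is what the perfsprint pass targets here. A sketch with a throwaway digest:

    package main

    import (
        "crypto/sha256"
        "encoding/hex"
        "fmt"
    )

    func main() {
        sum := sha256.Sum256([]byte("ollama"))

        short1 := fmt.Sprintf("%x", sum[:4])  // before
        short2 := hex.EncodeToString(sum[:4]) // after

        fmt.Println(short1 == short2) // true
    }
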

View File

@ -1184,11 +1184,11 @@ func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
} }
start, err := strconv.ParseInt(startPart, 10, 64) start, err := strconv.ParseInt(startPart, 10, 64)
if err != nil { if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err) return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %w", s, err)
} }
end, err := strconv.ParseInt(endPart, 10, 64) end, err := strconv.ParseInt(endPart, 10, 64)
if err != nil { if err != nil {
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err) return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %w", s, err)
} }
if start > end { if start > end {
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s) return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)

View File

@ -142,7 +142,7 @@ var junkName Name
func BenchmarkParseName(b *testing.B) { func BenchmarkParseName(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for range b.N { for b.Loop() {
junkName = Parse("h/n/m:t") junkName = Parse("h/n/m:t")
} }
} }
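
testing.B.Loop (Go 1.24) replaces the classic for range b.N pattern; it manages the iteration count and timer itself and keeps the compiler from optimizing the loop body away. A sketch of the same benchmark shape with a placeholder body:

    package bench_test

    import (
        "strings"
        "testing"
    )

    var sink int

    func BenchmarkFields(b *testing.B) {
        b.ReportAllocs()
        // Before: for range b.N { ... }
        // After (Go 1.24): b.Loop() handles iteration and timing.
        for b.Loop() {
            sink = len(strings.Fields("h n m t"))
        }
    }
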

View File

@ -187,15 +187,15 @@ func (w *relayWriter) Close() error {
return nil return nil
} }
func (t *relayWriter) awaitTurn() (ok bool) { func (w *relayWriter) awaitTurn() (ok bool) {
if t.ready { if w.ready {
return true return true
} }
select { select {
case <-t.t.Ready(): case <-w.t.Ready():
t.ready = true w.ready = true
return true return true
case <-t.q.closed(): case <-w.q.closed():
return false return false
} }
} }

View File

@ -251,7 +251,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
type progressUpdateJSON struct { type progressUpdateJSON struct {
Error string `json:"error,omitempty,omitzero"` Error string `json:"error,omitempty,omitzero"`
Status string `json:"status,omitempty,omitzero"` Status string `json:"status,omitempty,omitzero"`
Digest blob.Digest `json:"digest,omitempty,omitzero"` Digest blob.Digest `json:"digest,omitzero"`
Total int64 `json:"total,omitempty,omitzero"` Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"` Completed int64 `json:"completed,omitempty,omitzero"`
} }
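
omitempty never omits struct values, so on a blob.Digest field it was dead weight; omitzero (encoding/json, Go 1.24) consults the type's IsZero method or the zero value and actually drops it. A sketch with a hypothetical digest-like struct, not the real blob.Digest:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    type digest struct {
        Sum [4]byte
    }

    // IsZero lets encoding/json's omitzero tag decide when to drop the field.
    func (d digest) IsZero() bool { return d.Sum == [4]byte{} }

    type update struct {
        Status string `json:"status,omitempty,omitzero"`
        Digest digest `json:"digest,omitzero"`
        Total  int64  `json:"total,omitempty,omitzero"`
    }

    func main() {
        out, _ := json.Marshal(update{Status: "pulling"})
        fmt.Println(string(out)) // {"status":"pulling"} — zero digest and total omitted
    }
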

View File

@ -74,7 +74,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
return "", nil, errors.New("this model only supports one image while more than one image requested") return "", nil, errors.New("this model only supports one image while more than one image requested")
} }
var prefix string var prefix strings.Builder
prompt := msg.Content prompt := msg.Content
for _, i := range msg.Images { for _, i := range msg.Images {
@ -85,14 +85,14 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
imgTag := fmt.Sprintf("[img-%d]", imgData.ID) imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") { if !strings.Contains(prompt, "[img]") {
prefix += imgTag prefix.WriteString(imgTag)
} else { } else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1) prompt = strings.Replace(prompt, "[img]", imgTag, 1)
} }
images = append(images, imgData) images = append(images, imgData)
} }
msgs[currMsgIdx+cnt].Content = prefix + prompt msgs[currMsgIdx+cnt].Content = prefix.String() + prompt
} }
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
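
Accumulating the image-tag prefix in a strings.Builder avoids reallocating the string on every += in the loop; the single String() call is paid once at the end. A minimal sketch of the pattern with made-up tags:

    package main

    import (
        "fmt"
        "strings"
    )

    func main() {
        tags := []string{"[img-0]", "[img-1]", "[img-2]"}

        // Before: prefix += tag copies the accumulated string each iteration.
        // After: strings.Builder appends into a growing buffer.
        var prefix strings.Builder
        for _, tag := range tags {
            prefix.WriteString(tag)
        }

        prompt := "describe these images"
        fmt.Println(prefix.String() + prompt)
    }
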

View File

@ -2,6 +2,7 @@ package server
import ( import (
"bytes" "bytes"
"errors"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -238,7 +239,7 @@ func TestChatPrompt(t *testing.T) {
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate) prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
if tt.error == nil && err != nil { if tt.error == nil && err != nil {
t.Fatal(err) t.Fatal(err)
} else if tt.error != nil && err != tt.error { } else if tt.error != nil && !errors.Is(err, tt.error) {
t.Fatalf("expected err '%q', got '%q'", tt.error, err) t.Fatalf("expected err '%q', got '%q'", tt.error, err)
} }
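
Comparing errors with == only matches the exact sentinel value; once anything in the call chain wraps the error, the test breaks. errors.Is walks the wrap chain, which is what the errorlint pass enforces in this test. A small sketch with placeholder errors:

    package main

    import (
        "errors"
        "fmt"
    )

    var errTooManyImages = errors.New("this model only supports one image")

    func chat() error {
        return fmt.Errorf("chat prompt: %w", errTooManyImages)
    }

    func main() {
        err := chat()
        fmt.Println(err == errTooManyImages)          // false: the error is wrapped
        fmt.Println(errors.Is(err, errTooManyImages)) // true: matches through the chain
    }
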

View File

@ -31,7 +31,7 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
data, err := io.ReadAll(sr) data, err := io.ReadAll(sr)
if err != nil { if err != nil {
slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err) slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err) return 0, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
} }
var f32s []float32 var f32s []float32
newType := fsggml.TensorType(q.to.Kind) newType := fsggml.TensorType(q.to.Kind)

View File

@ -12,6 +12,7 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"maps"
"math" "math"
"math/rand" "math/rand"
"net" "net"
@ -129,7 +130,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
} }
if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 { if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'") return nil, nil, nil, errors.New("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
} }
if err := model.CheckCapabilities(caps...); err != nil { if err := model.CheckCapabilities(caps...); err != nil {
@ -361,11 +362,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Think == nil { if req.Think == nil {
req.Think = &api.ThinkValue{Value: true} req.Think = &api.ThinkValue{Value: true}
} }
} else { } else if req.Think != nil && req.Think.Bool() {
if req.Think != nil && req.Think.Bool() { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)}) return
return
}
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@ -649,10 +648,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
truncate := true truncate := req.Truncate == nil || *req.Truncate
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string var input []string
@ -825,9 +821,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
var e []float64 e := make([]float64, len(embedding))
for _, v := range embedding { for i, v := range embedding {
e = append(e, float64(v)) e[i] = float64(v)
} }
resp := api.EmbeddingResponse{ resp := api.EmbeddingResponse{
@ -1139,9 +1135,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if m.Options == nil { if m.Options == nil {
m.Options = make(map[string]any) m.Options = make(map[string]any)
} }
for k, v := range req.Options { maps.Copy(m.Options, req.Options)
m.Options[k] = v
}
} }
var sb strings.Builder var sb strings.Builder
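
maps.Copy (Go 1.21) replaces the hand-written copy loop above; it inserts every key/value pair from the source map into the destination, overwriting existing keys. A sketch with illustrative option maps:

    package main

    import (
        "fmt"
        "maps"
    )

    func main() {
        modelOpts := map[string]any{"temperature": 0.8, "num_ctx": 2048}
        reqOpts := map[string]any{"temperature": 0.2, "seed": 42}

        // Before:
        //   for k, v := range reqOpts { modelOpts[k] = v }
        // After: request options override model defaults in place.
        maps.Copy(modelOpts, reqOpts)

        fmt.Println(modelOpts) // map[num_ctx:2048 seed:42 temperature:0.2]
    }
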
@ -1212,7 +1206,7 @@ func (s *Server) ListHandler(c *gin.Context) {
return return
} }
models := []api.ListModelResponse{} models := make([]api.ListModelResponse, 0, len(ms))
for n, m := range ms { for n, m := range ms {
var cf ConfigV2 var cf ConfigV2
@ -1811,13 +1805,13 @@ func (s *Server) PsHandler(c *gin.Context) {
ExpiresAt: v.expiresAt, ExpiresAt: v.expiresAt,
} }
if v.Options != nil { if v.Options != nil {
mr.ContextLength = v.Options.NumCtx mr.ContextLength = v.NumCtx
} }
// The scheduler waits to set expiresAt, so if a model is loading it's // The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just // possible that it will be set to the unix epoch. For those cases, just
// calculate the time w/ the sessionDuration instead. // calculate the time w/ the sessionDuration instead.
var epoch time.Time var epoch time.Time
if v.expiresAt == epoch { if v.expiresAt.Equal(epoch) {
mr.ExpiresAt = time.Now().Add(v.sessionDuration) mr.ExpiresAt = time.Now().Add(v.sessionDuration)
} }
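
Comparing time.Time values with == also compares the monotonic clock reading and the location pointer, so two values naming the same instant can compare unequal; Equal compares the instant alone, which is why staticcheck flags the == form above. A short sketch:

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        var epoch time.Time // zero value, as in the handler above

        loaded := time.Now()
        inUTC := loaded.UTC() // same instant, different location and no monotonic reading

        fmt.Println(loaded == inUTC)          // false: wall-clock metadata differs
        fmt.Println(loaded.Equal(inUTC))      // true: same instant
        fmt.Println(epoch.Equal(time.Time{})) // true: zero-value check
    }
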
@ -2000,11 +1994,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Think == nil { if req.Think == nil {
req.Think = &api.ThinkValue{Value: true} req.Think = &api.ThinkValue{Value: true}
} }
} else { } else if req.Think != nil && req.Think.Bool() {
if req.Think != nil && req.Think.Bool() { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)}) return
return
}
} }
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)

View File

@ -196,11 +196,9 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages { if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount) t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
} }
} else { } else if w.Code != http.StatusOK {
// When debug is disabled, it should attempt normal processing // When debug is disabled, it should attempt normal processing
if w.Code != http.StatusOK { t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
} }
}) })
} }
@ -401,11 +399,9 @@ func TestChatDebugRenderOnly(t *testing.T) {
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages { if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount) t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
} }
} else { } else if w.Code != http.StatusOK {
// When debug is disabled, it should attempt normal processing // When debug is disabled, it should attempt normal processing
if w.Code != http.StatusOK { t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
} }
}) })
} }

View File

@ -93,7 +93,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
t.Fatalf("expected status 200, got %d", w.Code) t.Fatalf("expected status 200, got %d", w.Code)
} }
mock.CompletionResponse.Content = "Hi!" mock.Content = "Hi!"
t.Run("chat-like flow uses renderer", func(t *testing.T) { t.Run("chat-like flow uses renderer", func(t *testing.T) {
// Test that when using messages (chat-like flow), the built-in renderer is used // Test that when using messages (chat-like flow), the built-in renderer is used
@ -109,12 +109,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags // The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
// When messages are built internally from prompt, it should use the renderer // When messages are built internally from prompt, it should use the renderer
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") { if !strings.Contains(mock.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt) t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.Prompt)
} }
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") { if !strings.Contains(mock.Prompt, "<|im_end|>") {
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt) t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.Prompt)
} }
}) })
@ -132,12 +132,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
} }
// Should contain the system message and use renderer format // Should contain the system message and use renderer format
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") { if !strings.Contains(mock.Prompt, "<|im_start|>system") {
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt) t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.Prompt)
} }
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") { if !strings.Contains(mock.Prompt, "You are a helpful coding assistant.") {
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt) t.Errorf("expected prompt to contain system message content, got: %s", mock.Prompt)
} }
}) })
@ -155,12 +155,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
} }
// Should NOT use the renderer format when custom template is provided // Should NOT use the renderer format when custom template is provided
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") { if strings.Contains(mock.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt) t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.Prompt)
} }
// Should just be the raw prompt from the template // Should just be the raw prompt from the template
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" { if diff := cmp.Diff(mock.Prompt, "Write a hello world function"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
@ -191,12 +191,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
} }
// Should NOT use the renderer format when suffix is provided // Should NOT use the renderer format when suffix is provided
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") { if strings.Contains(mock.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt) t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.Prompt)
} }
// Should use the suffix template format // Should use the suffix template format
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" { if diff := cmp.Diff(mock.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
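
The mock.CompletionRequest.Prompt → mock.Prompt renames read as if the mock runner embeds the request and response structs, so their fields are reachable as promoted fields; that is an inference from the diff, not shown in it, so the sketch below uses hypothetical types rather than the real llm package:

    package main

    import "fmt"

    type CompletionRequest struct{ Prompt string }
    type CompletionResponse struct{ Content string }

    // Embedding (no field name) promotes Prompt and Content onto mockRunner.
    type mockRunner struct {
        CompletionRequest
        CompletionResponse
    }

    func main() {
        var mock mockRunner
        mock.Content = "Hi!" // same storage as mock.CompletionResponse.Content
        mock.CompletionRequest.Prompt = "user: Hello!\n"

        fmt.Println(mock.Prompt, mock.Content)
    }
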

View File

@ -41,7 +41,7 @@ func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn
} }
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) { func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) { for range strings.FieldsSeq(s) {
tokens = append(tokens, len(tokens)) tokens = append(tokens, len(tokens))
} }
@ -378,7 +378,7 @@ func TestGenerateChat(t *testing.T) {
} }
} }
mock.CompletionResponse.Content = "Hi!" mock.Content = "Hi!"
t.Run("messages", func(t *testing.T) { t.Run("messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{ w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test", Model: "test",
@ -392,7 +392,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" { if diff := cmp.Diff(mock.Prompt, "user: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -422,14 +422,14 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" { if diff := cmp.Diff(mock.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
checkChatResponse(t, w.Body, "test-system", "Hi!") checkChatResponse(t, w.Body, "test-system", "Hi!")
}) })
mock.CompletionResponse.Content = "Abra kadabra!" mock.Content = "Abra kadabra!"
t.Run("messages with system", func(t *testing.T) { t.Run("messages with system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{ w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system", Model: "test-system",
@ -444,7 +444,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" { if diff := cmp.Diff(mock.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -467,7 +467,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" { if diff := cmp.Diff(mock.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -985,7 +985,7 @@ func TestGenerate(t *testing.T) {
} }
} }
mock.CompletionResponse.Content = "Hi!" mock.Content = "Hi!"
t.Run("prompt", func(t *testing.T) { t.Run("prompt", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test", Model: "test",
@ -997,7 +997,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" { if diff := cmp.Diff(mock.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -1025,14 +1025,14 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" { if diff := cmp.Diff(mock.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
checkGenerateResponse(t, w.Body, "test-system", "Hi!") checkGenerateResponse(t, w.Body, "test-system", "Hi!")
}) })
mock.CompletionResponse.Content = "Abra kadabra!" mock.Content = "Abra kadabra!"
t.Run("prompt with system", func(t *testing.T) { t.Run("prompt with system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system", Model: "test-system",
@ -1045,7 +1045,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" { if diff := cmp.Diff(mock.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -1067,7 +1067,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" { if diff := cmp.Diff(mock.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@ -1097,7 +1097,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" { if diff := cmp.Diff(mock.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
@ -1112,7 +1112,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" { if diff := cmp.Diff(mock.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })
@ -1129,7 +1129,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" { if diff := cmp.Diff(mock.Prompt, "Help me write tests."); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })

Some files were not shown because too many files have changed in this diff.