Compare commits: main...mxyng/lint (8 commits)

Commits:
- b59053a883
- bb93e5afe7
- 4d24d8a77d
- f01c83ed6d
- d3228355be
- 78a75a30d8
- 974ae8ef84
- efd9f5e67e
@@ -36,6 +36,12 @@ linters:
     errcheck:
       exclude-functions:
         - fmt.Fprintf
+    gocritic:
+      disabled-checks:
+        # Detects suspicious duplicated sub-expressions.
+        # Prone to false positives when used on cgo code
+        # https://github.com/go-critic/go-critic/issues/897#issuecomment-568892104
+        - dupSubExpr
     perfsprint:
       strconcat: false
       concat-loop: false
@@ -45,24 +51,22 @@ linters:
       # Using a deprecated function, variable, constant or field.
       # https://staticcheck.dev/docs/checks/#SA1019
       - -SA1019
-      # Incorrect or missing package comment.
-      # https://staticcheck.dev/docs/checks/#ST1000
-      - -ST1000
       # Poorly chosen identifier.
       # https://staticcheck.dev/docs/checks/#ST1003
       - -ST1003
-      # The documentation of an exported function should start with the function's name.
-      # https://staticcheck.dev/docs/checks/#ST1020
-      - -ST1020
-      # The documentation of an exported type should start with type's name.
-      # https://staticcheck.dev/docs/checks/#ST1021
-      - -ST1021
-      # The documentation of an exported variable or constant should start with variable's name.
-      # https://staticcheck.dev/docs/checks/#ST1022
-      - -ST1022
     usestdlibvars:
       http-method: false
       http-status-code: false
+  exclusions:
+    presets:
+      - comments
+      - common-false-positives
+      - legacy
+      - std-error-handling
+    rules:
+      - path: _test\.go
+        linters:
+          - prealloc

 formatters:
   enable:
@@ -2,6 +2,7 @@ package api

 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"net/http"
 	"net/http/httptest"
@@ -39,7 +40,7 @@ func TestClientFromEnvironment(t *testing.T) {
 			t.Setenv("OLLAMA_HOST", v.value)

 			client, err := ClientFromEnvironment()
-			if err != v.err {
+			if !errors.Is(err, v.err) {
 				t.Fatalf("expected %s, got %s", v.err, err)
 			}
api/types.go (25)
@@ -2,6 +2,7 @@ package api

 import (
 	"encoding/json"
+	"errors"
 	"fmt"
 	"log/slog"
 	"math"
@@ -308,9 +309,9 @@ func (tp ToolProperty) ToTypeScriptType() string {
 		return mapToTypeScriptType(tp.Type[0])
 	}

-	var types []string
-	for _, t := range tp.Type {
-		types = append(types, mapToTypeScriptType(t))
+	types := make([]string, len(tp.Type))
+	for i, t := range tp.Type {
+		types[i] = mapToTypeScriptType(t)
 	}
 	return strings.Join(types, " | ")
 }
@@ -783,7 +784,7 @@ func (m *Metrics) Summary() {

 func (opts *Options) FromMap(m map[string]any) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
-	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
+	typeOpts := reflect.TypeFor[Options]()    // types of the fields in the options struct

 	// build map of json struct tags to their types
 	jsonOpts := make(map[string]reflect.StructField)
@@ -854,8 +855,7 @@ func (opts *Options) FromMap(m map[string]any) error {
 			}
 			field.Set(reflect.ValueOf(slice))
 		case reflect.Pointer:
-			var b bool
-			if field.Type() == reflect.TypeOf(&b) {
+			if field.Type() == reflect.TypeFor[*bool]() {
 				val, ok := val.(bool)
 				if !ok {
 					return fmt.Errorf("option %q must be of type boolean", key)
@@ -906,7 +906,7 @@ func DefaultOptions() Options {
 // ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
 type ThinkValue struct {
 	// Value can be a bool or string
-	Value interface{}
+	Value any
 }

 // IsValid checks if the ThinkValue is valid
@@ -999,7 +999,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
 		return nil
 	}

-	return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
+	return errors.New("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
 }

 // MarshalJSON implements json.Marshaler
@@ -1018,7 +1018,7 @@ func (d Duration) MarshalJSON() ([]byte, error) {
 	if d.Duration < 0 {
 		return []byte("-1"), nil
 	}
-	return []byte("\"" + d.Duration.String() + "\""), nil
+	return []byte("\"" + d.String() + "\""), nil
 }

 func (d *Duration) UnmarshalJSON(b []byte) (err error) {
@@ -1045,7 +1045,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 			d.Duration = time.Duration(math.MaxInt64)
 		}
 	default:
-		return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
+		return fmt.Errorf("unsupported type: '%s'", reflect.TypeOf(v))
 	}

 	return nil
@@ -1055,7 +1055,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 func FormatParams(params map[string][]string) (map[string]any, error) {
 	opts := Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
-	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
+	typeOpts := reflect.TypeFor[Options]()     // types of the fields in the options struct

 	// build map of json struct tags to their types
 	jsonOpts := make(map[string]reflect.StructField)
@@ -1102,8 +1102,7 @@ func FormatParams(params map[string][]string) (map[string]any, error) {
 		// TODO: only string slices are supported right now
 		out[key] = vals
 	case reflect.Pointer:
-		var b bool
-		if field.Type() == reflect.TypeOf(&b) {
+		if field.Type() == reflect.TypeFor[*bool]() {
 			boolVal, err := strconv.ParseBool(vals[0])
 			if err != nil {
 				return nil, fmt.Errorf("invalid bool value %s", vals)
@@ -22,6 +22,7 @@ import (
 var ErrCancelled = errors.New("Cancelled")

 // Cancelled refers to ErrCancelled.
+//
 // Deprecated: Use ErrCancelled instead.
 var Cancelled = ErrCancelled

@@ -37,7 +38,7 @@ type MsgBuilder struct {
 }

 // Message initialises a MsgBuilder with the provided message.
-func Message(format string, args ...interface{}) *MsgBuilder {
+func Message(format string, args ...any) *MsgBuilder {
 	return &MsgBuilder{Msg: fmt.Sprintf(format, args...)}
 }

@@ -319,7 +319,7 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {
 	for {
 		select {
 		case <-ctx.Done():
-			return nil, fmt.Errorf("timeout scanning server log for inference compute details")
+			return nil, errors.New("timeout scanning server log for inference compute details")
 		default:
 		}
 		file, err := os.Open(serverLogPath)
@@ -345,11 +345,9 @@ func GetInferenceComputer(ctx context.Context) ([]InferenceCompute, error) {

 				slog.Info("Matched", "inference compute", ic)
 				inference = append(inference, ic)
-			} else {
+			} else if len(inference) > 0 {
 				// Break out on first non matching line after we start matching
-				if len(inference) > 0 {
-					return inference, nil
-				}
+				return inference, nil
 			}
 		}
 		time.Sleep(100 * time.Millisecond)
@@ -31,7 +31,7 @@ func terminate(proc *os.Process) error {
 func terminated(pid int) (bool, error) {
 	proc, err := os.FindProcess(pid)
 	if err != nil {
-		return false, fmt.Errorf("failed to find process: %v", err)
+		return false, fmt.Errorf("failed to find process: %w", err)
 	}

 	err = proc.Signal(syscall.Signal(0))
@@ -40,7 +40,7 @@ func terminated(pid int) (bool, error) {
 		return true, nil
 	}

-		return false, fmt.Errorf("error signaling process: %v", err)
+		return false, fmt.Errorf("error signaling process: %w", err)
 	}

 	return false, nil
@@ -67,8 +67,7 @@ func reapServers() error {
 		return nil
 	}

-	pids := strings.Split(pidsStr, "\n")
-	for _, pidStr := range pids {
+	for pidStr := range strings.SplitSeq(pidsStr, "\n") {
 		pidStr = strings.TrimSpace(pidStr)
 		if pidStr == "" {
 			continue
@@ -5,6 +5,7 @@ package store
 import (
 	"database/sql"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"strings"
 	"time"
@@ -482,7 +483,8 @@ func (db *database) cleanupOrphanedData() error {
 }

 func duplicateColumnError(err error) bool {
-	if sqlite3Err, ok := err.(sqlite3.Error); ok {
+	var sqlite3Err sqlite3.Error
+	if errors.As(err, &sqlite3Err) {
 		return sqlite3Err.Code == sqlite3.ErrError &&
 			strings.Contains(sqlite3Err.Error(), "duplicate column name")
 	}
@@ -490,7 +492,8 @@ func duplicateColumnError(err error) bool {
 }

 func columnNotExists(err error) bool {
-	if sqlite3Err, ok := err.(sqlite3.Error); ok {
+	var sqlite3Err sqlite3.Error
+	if errors.As(err, &sqlite3Err) {
 		return sqlite3Err.Code == sqlite3.ErrError &&
 			strings.Contains(sqlite3Err.Error(), "no such column")
 	}
@@ -586,8 +589,8 @@ func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Cha
 		&browserState,
 	)
 	if err != nil {
-		if err == sql.ErrNoRows {
-			return nil, fmt.Errorf("chat not found")
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, errors.New("chat not found")
 		}
 		return nil, fmt.Errorf("query chat: %w", err)
 	}
@@ -752,7 +755,7 @@ func (db *database) updateLastMessage(chatID string, msg Message) error {
 		return fmt.Errorf("get rows affected: %w", err)
 	}
 	if rowsAffected == 0 {
-		return fmt.Errorf("no message found to update")
+		return errors.New("no message found to update")
 	}

 	_, err = tx.Exec("DELETE FROM attachments WHERE message_id = ?", messageID)
@@ -282,7 +282,7 @@ func countRows(t *testing.T, db *database, table string) int {
 	return count
 }

-func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...interface{}) int {
+func countRowsWithCondition(t *testing.T, db *database, table, condition string, args ...any) int {
 	t.Helper()
 	var count int
 	query := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s", table, condition)
@@ -296,7 +296,7 @@ func countRowsWithCondition(t *testing.T, db *database, table, condition string,
 // Test helpers for schema migration testing

 // schemaMap returns both tables/columns and indexes (ignoring order)
-func schemaMap(db *database) map[string]interface{} {
+func schemaMap(db *database) map[string]any {
 	result := make(map[string]any)

 	result["tables"] = columnMap(db)
@@ -5,6 +5,7 @@ package store
 import (
 	"crypto/sha256"
 	"encoding/hex"
+	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -26,7 +27,7 @@ func (i *Image) Bytes() ([]byte, error) {
 // ImgBytes reads image data from the specified file path
 func ImgBytes(path string) ([]byte, error) {
 	if path == "" {
-		return nil, fmt.Errorf("empty image path")
+		return nil, errors.New("empty image path")
 	}

 	data, err := os.ReadFile(path)
@@ -4,6 +4,7 @@ package tools

 import (
 	"context"
+	"errors"
 	"fmt"
 	"net/url"
 	"regexp"
@@ -130,7 +131,7 @@ func (b *BrowserSearch) Schema() map[string]any {
 func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
 	query, ok := args["query"].(string)
 	if !ok {
-		return nil, "", fmt.Errorf("query parameter is required")
+		return nil, "", errors.New("query parameter is required")
 	}

 	topn, ok := args["topn"].(int)
@@ -150,7 +151,7 @@ func (b *BrowserSearch) Execute(ctx context.Context, args map[string]any) (any,

 	searchResponse, ok := result.(*WebSearchResponse)
 	if !ok {
-		return nil, "", fmt.Errorf("invalid search results format")
+		return nil, "", errors.New("invalid search results format")
 	}

 	// Build main search results page that contains all search results
@@ -383,15 +384,9 @@ func wrapLines(text string, width int) []string {
 			wrapped = append(wrapped, "")
 		} else if len(line) <= width {
 			wrapped = append(wrapped, line)
+		} else if words := strings.Fields(line); len(words) == 0 {
+			wrapped = append(wrapped, line)
 		} else {
-			// Word wrapping while preserving whitespace structure
-			words := strings.Fields(line)
-			if len(words) == 0 {
-				// Line with only whitespace
-				wrapped = append(wrapped, line)
-				continue
-			}
-
 			currentLine := ""
 			for _, word := range words {
 				// Check if adding this word would exceed width
@@ -536,15 +531,13 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
 		if err != nil {
 			return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
 		}
-	} else {
+	} else if len(b.state.Data.PageStack) != 0 {
 		// get last page
-		if len(b.state.Data.PageStack) != 0 {
-			pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
-			var err error
-			page, err = b.getPageFromStack(pageURL)
-			if err != nil {
-				return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
-			}
-		}
+		pageURL := b.state.Data.PageStack[len(b.state.Data.PageStack)-1]
+		var err error
+		page, err = b.getPageFromStack(pageURL)
+		if err != nil {
+			return nil, "", fmt.Errorf("page not found for cursor %d: %w", cursor, err)
+		}
 	}

@@ -594,7 +587,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st
 	// Try to get id as integer (link ID from current page)
 	if id, ok := args["id"].(float64); ok {
 		if page == nil {
-			return nil, "", fmt.Errorf("no current page to resolve link from")
+			return nil, "", errors.New("no current page to resolve link from")
 		}
 		idInt := int(id)
 		pageURL, ok := page.Links[idInt]
@@ -637,7 +630,7 @@ func (b *BrowserOpen) Execute(ctx context.Context, args map[string]any) (any, st

 	// If no id provided, just display current page
 	if page == nil {
-		return nil, "", fmt.Errorf("no current page to display")
+		return nil, "", errors.New("no current page to display")
 	}
 	// Only add to PageStack without updating URLToPage
 	b.state.Data.PageStack = append(b.state.Data.PageStack, page.URL)
@@ -742,7 +735,7 @@ func (b *BrowserFind) Schema() map[string]any {
 func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, string, error) {
 	pattern, ok := args["pattern"].(string)
 	if !ok {
-		return nil, "", fmt.Errorf("pattern parameter is required")
+		return nil, "", errors.New("pattern parameter is required")
 	}

 	// Get cursor parameter if provided, default to current page
@@ -756,7 +749,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
 	if cursor == -1 {
 		// Use current page
 		if len(b.state.Data.PageStack) == 0 {
-			return nil, "", fmt.Errorf("no pages to search in")
+			return nil, "", errors.New("no pages to search in")
 		}
 		var err error
 		page, err = b.getPageFromStack(b.state.Data.PageStack[len(b.state.Data.PageStack)-1])
@@ -776,7 +769,7 @@ func (b *BrowserFind) Execute(ctx context.Context, args map[string]any) (any, st
 	}

 	if page == nil {
-		return nil, "", fmt.Errorf("page not found")
+		return nil, "", errors.New("page not found")
 	}

 	// Create find results page
@@ -5,6 +5,7 @@ package tools
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 )

@@ -87,7 +88,7 @@ func (g *BrowserCrawler) Schema() map[string]any {
 func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*CrawlResponse, error) {
 	urlsRaw, ok := args["urls"].([]any)
 	if !ok {
-		return nil, fmt.Errorf("urls parameter is required and must be an array of strings")
+		return nil, errors.New("urls parameter is required and must be an array of strings")
 	}

 	urls := make([]string, 0, len(urlsRaw))
@@ -98,7 +99,7 @@ func (g *BrowserCrawler) Execute(ctx context.Context, args map[string]any) (*Cra
 	}

 	if len(urls) == 0 {
-		return nil, fmt.Errorf("at least one URL is required")
+		return nil, errors.New("at least one URL is required")
 	}

 	return g.performWebCrawl(ctx, urls)
@@ -5,6 +5,7 @@ package tools
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"strconv"
 	"time"
@@ -84,7 +85,7 @@ func (w *BrowserWebSearch) Schema() map[string]any {
 func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (any, error) {
 	queriesRaw, ok := args["queries"].([]any)
 	if !ok {
-		return nil, fmt.Errorf("queries parameter is required and must be an array of strings")
+		return nil, errors.New("queries parameter is required and must be an array of strings")
 	}

 	queries := make([]string, 0, len(queriesRaw))
@@ -95,7 +96,7 @@ func (w *BrowserWebSearch) Execute(ctx context.Context, args map[string]any) (an
 	}

 	if len(queries) == 0 {
-		return nil, fmt.Errorf("at least one query is required")
+		return nil, errors.New("at least one query is required")
 	}

 	maxResults := 5
@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"net/http"
 	"net/url"
@@ -36,7 +37,7 @@ func (w *WebFetch) Description() string {
 	return "Crawl and extract text content from web pages"
 }

-func (g *WebFetch) Schema() map[string]any {
+func (w *WebFetch) Schema() map[string]any {
 	schemaBytes := []byte(`{
 		"type": "object",
 		"properties": {
@@ -61,11 +62,11 @@ func (w *WebFetch) Prompt() string {
 func (w *WebFetch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
 	urlRaw, ok := args["url"]
 	if !ok {
-		return nil, "", fmt.Errorf("url parameter is required")
+		return nil, "", errors.New("url parameter is required")
 	}
 	urlStr, ok := urlRaw.(string)
 	if !ok || strings.TrimSpace(urlStr) == "" {
-		return nil, "", fmt.Errorf("url must be a non-empty string")
+		return nil, "", errors.New("url must be a non-empty string")
 	}

 	result, err := performWebFetch(ctx, urlStr)
@@ -6,6 +6,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"net/http"
 	"net/url"
@@ -45,7 +46,7 @@ func (w *WebSearch) Prompt() string {
 	return ""
 }

-func (g *WebSearch) Schema() map[string]any {
+func (w *WebSearch) Schema() map[string]any {
 	schemaBytes := []byte(`{
 		"type": "object",
 		"properties": {
@@ -71,12 +72,12 @@ func (g *WebSearch) Schema() map[string]any {
 func (w *WebSearch) Execute(ctx context.Context, args map[string]any) (any, string, error) {
 	rawQuery, ok := args["query"]
 	if !ok {
-		return nil, "", fmt.Errorf("query parameter is required")
+		return nil, "", errors.New("query parameter is required")
 	}

 	queryStr, ok := rawQuery.(string)
 	if !ok || strings.TrimSpace(queryStr) == "" {
-		return nil, "", fmt.Errorf("query must be a non-empty string")
+		return nil, "", errors.New("query must be a non-empty string")
 	}

 	maxResults := 5
@@ -19,10 +19,12 @@ import (
 // Errors wrapping Found should provide additional context, e.g.
 // fmt.Errorf("%w: %s", not.Found, key)
 //
+//nolint:staticcheck
 //lint:ignore ST1012 This is a sentinel error intended to be read like not.Found.
 var Found = errors.New("not found")

 // Available is an error that indicates that a value is not available.
 //
+//nolint:staticcheck
 //lint:ignore ST1012 This is a sentinel error intended to be read like not.Available.
 var Available = errors.New("not available")
@@ -4,6 +4,7 @@ package not

 import (
 	"fmt"
+	"strings"
 )

 type ValidError struct {
@@ -44,12 +45,12 @@ func (b Valids) Error() string {
 		return ""
 	}

-	var result string
+	var sb strings.Builder
 	for i, err := range b {
 		if i > 0 {
-			result += "; "
+			sb.WriteString("; ")
 		}
-		result += err.Error()
+		sb.WriteString(err.Error())
 	}
-	return result
+	return sb.String()
 }
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"path/filepath"
 	"slices"
+	"strconv"
 	"strings"
 	"unicode/utf8"

@@ -73,7 +74,7 @@ func extractPDFText(data []byte) (string, error) {
 		if strings.TrimSpace(text) != "" {
 			if textBuilder.Len() > 0 {
 				textBuilder.WriteString("\n\n--- Page ")
-				textBuilder.WriteString(fmt.Sprintf("%d", i))
+				textBuilder.WriteString(strconv.Itoa(i))
 				textBuilder.WriteString(" ---\n")
 			}
 			textBuilder.WriteString(text)
app/ui/ui.go (32)
@@ -194,7 +194,7 @@ func (s *Server) Handler() http.Handler {
 		log := s.log()
 		level := slog.LevelInfo
 		start := time.Now()
-		requestID := fmt.Sprintf("%d", time.Now().UnixNano())
+		requestID := strconv.FormatInt(time.Now().UnixNano(), 10)

 		defer func() {
 			p := recover()
@@ -204,7 +204,7 @@ func (s *Server) Handler() http.Handler {

 			// Handle panic with user-friendly error
 			if !sw.Written() {
-				s.handleError(sw, fmt.Errorf("internal server error"))
+				s.handleError(sw, errors.New("internal server error"))
 			}
 		}

@@ -382,7 +382,7 @@ func waitForServer(ctx context.Context) error {
 			break
 		}
 		if time.Now().After(timeout) {
-			return fmt.Errorf("timeout waiting for Ollama server to be ready")
+			return errors.New("timeout waiting for Ollama server to be ready")
 		}
 		time.Sleep(10 * time.Millisecond)
 	}
@@ -455,7 +455,7 @@ func (s *Server) checkModelUpstream(ctx context.Context, modelName string, timeo

 	digest := resp.Header.Get("ollama-content-digest")
 	if digest == "" {
-		return "", 0, fmt.Errorf("no digest header found")
+		return "", 0, errors.New("no digest header found")
 	}

 	var pushTime int64
@@ -598,12 +598,12 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
 	}

 	if req.Model == "" {
-		return fmt.Errorf("empty model")
+		return errors.New("empty model")
 	}

 	// Don't allow empty messages unless forceUpdate is true
 	if req.Prompt == "" && !req.ForceUpdate {
-		return fmt.Errorf("empty message")
+		return errors.New("empty message")
 	}

 	if createdChat {
@@ -942,7 +942,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
 			} else {
 				onlyStandalone := true
 				for _, tc := range res.Message.ToolCalls {
-					if !(tc.Function.Name == "web_search" || tc.Function.Name == "web_fetch") {
+					if tc.Function.Name != "web_search" && tc.Function.Name != "web_fetch" {
 						onlyStandalone = false
 						break
 					}
@@ -1194,7 +1194,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
 	cid := r.PathValue("id")

 	if cid == "" {
-		return fmt.Errorf("chat ID is required")
+		return errors.New("chat ID is required")
 	}

 	chat, err := s.Store.Chat(cid)
@@ -1252,7 +1252,7 @@ func (s *Server) getChat(w http.ResponseWriter, r *http.Request) error {
 func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
 	cid := r.PathValue("id")
 	if cid == "" {
-		return fmt.Errorf("chat ID is required")
+		return errors.New("chat ID is required")
 	}

 	var req struct {
@@ -1283,7 +1283,7 @@ func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
 func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
 	cid := r.PathValue("id")
 	if cid == "" {
-		return fmt.Errorf("chat ID is required")
+		return errors.New("chat ID is required")
 	}

 	// Check if the chat exists (no need to load attachments)
@@ -1291,7 +1291,7 @@ func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
 	if err != nil {
 		if errors.Is(err, not.Found) {
 			w.WriteHeader(http.StatusNotFound)
-			return fmt.Errorf("chat not found")
+			return errors.New("chat not found")
 		}
 		return nil, fmt.Errorf("failed to get chat: %w", err)
 	}
@@ -1592,7 +1592,7 @@ func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) err

 func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
 	if r.Method != "POST" {
-		return fmt.Errorf("method not allowed")
+		return errors.New("method not allowed")
 	}

 	var req struct {
@@ -1603,7 +1603,7 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
 	}

 	if req.Model == "" {
-		return fmt.Errorf("model is required")
+		return errors.New("model is required")
 	}

 	digest, pushTime, err := s.checkModelUpstream(r.Context(), req.Model, 5*time.Second)
@@ -1730,8 +1730,8 @@ func supportsWebSearchTools(model string) bool {

 // buildChatRequest converts store.Chat to api.ChatRequest
 func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, availableTools []map[string]any) (*api.ChatRequest, error) {
-	var msgs []api.Message
-	for _, m := range chat.Messages {
+	msgs := make([]api.Message, len(chat.Messages))
+	for i, m := range chat.Messages {
 		// Skip empty messages if present
 		if m.Content == "" && m.Thinking == "" && len(m.ToolCalls) == 0 && len(m.Attachments) == 0 {
 			continue
@@ -1789,7 +1789,7 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
 			s.log().Debug("unknown message role", "role", m.Role)
 		}

-		msgs = append(msgs, apiMsg)
+		msgs[i] = apiMsg
 	}

 	var thinkValue *api.ThinkValue
@@ -198,7 +198,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
 	_, err = os.Stat(filepath.Dir(stageFilename))
 	if errors.Is(err, os.ErrNotExist) {
 		if err := os.MkdirAll(filepath.Dir(stageFilename), 0o755); err != nil {
-			return fmt.Errorf("create ollama dir %s: %v", filepath.Dir(stageFilename), err)
+			return fmt.Errorf("create ollama dir %s: %w", filepath.Dir(stageFilename), err)
 		}
 	}

@@ -218,7 +218,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo

 	if err := VerifyDownload(); err != nil {
 		_ = os.Remove(stageFilename)
-		return fmt.Errorf("%s - %s", resp.Request.URL.String(), err)
+		return fmt.Errorf("%s - %w", resp.Request.URL.String(), err)
 	}
 	UpdateDownloaded = true
 	return nil
@@ -92,7 +92,7 @@ func DoUpgrade(interactive bool) error {

 	bundle := getStagedUpdate()
 	if bundle == "" {
-		return fmt.Errorf("failed to lookup downloads")
+		return errors.New("failed to lookup downloads")
 	}

 	slog.Info("starting upgrade", "app", BundlePath, "update", bundle, "pid", os.Getpid(), "log", UpgradeLogFile)
@@ -107,7 +107,7 @@ func DoUpgrade(interactive bool) error {
 	// Verify old doesn't exist yet
 	if _, err := os.Stat(contentsOldName); err == nil {
 		slog.Error("prior upgrade failed", "backup", contentsOldName)
-		return fmt.Errorf("prior upgrade failed - please upgrade manually by installing the bundle")
+		return errors.New("prior upgrade failed - please upgrade manually by installing the bundle")
 	}
 	if err := os.MkdirAll(appBackupDir, 0o755); err != nil {
 		return fmt.Errorf("unable to create backup dir %s: %w", appBackupDir, err)
@@ -133,7 +133,7 @@ func DoUpgrade(interactive bool) error {
 		return err
 	}
 	if !chownWithAuthorization(u.Username) {
-		return fmt.Errorf("unable to change permissions to complete upgrade")
+		return errors.New("unable to change permissions to complete upgrade")
 	}
 	if err := os.Rename(BundlePath, appBackup); err != nil {
 		return fmt.Errorf("unable to perform upgrade - failed to stage old version: %w", err)
@@ -264,7 +264,7 @@ func DoPostUpgradeCleanup() error {
 func verifyDownload() error {
 	bundle := getStagedUpdate()
 	if bundle == "" {
-		return fmt.Errorf("failed to lookup downloads")
+		return errors.New("failed to lookup downloads")
 	}
 	slog.Debug("verifying update", "bundle", bundle)

@@ -338,7 +338,7 @@ func verifyDownload() error {
 	}

 	if err := verifyExtractedBundle(filepath.Join(dir, "Ollama.app")); err != nil {
-		return fmt.Errorf("signature verification failed: %s", err)
+		return fmt.Errorf("signature verification failed: %w", err)
 	}
 	return nil
 }
@@ -347,11 +347,11 @@ func verifyDownload() error {
 func DoUpgradeAtStartup() error {
 	bundle := getStagedUpdate()
 	if bundle == "" {
-		return fmt.Errorf("failed to lookup downloads")
+		return errors.New("failed to lookup downloads")
 	}

 	if BundlePath == "" {
-		return fmt.Errorf("unable to upgrade at startup, app in development mode")
+		return errors.New("unable to upgrade at startup, app in development mode")
 	}

 	// [Re]verify before proceeding
@@ -22,9 +22,7 @@ func TestIsNewReleaseAvailable(t *testing.T) {
 	var server *httptest.Server
 	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.URL.Path == "/update.json" {
-			w.Write([]byte(
-				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
-					server.URL+"/9.9.9/"+Installer)))
+			fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
 			// TODO - wire up the redirects to mimic real behavior
 		} else {
 			slog.Debug("unexpected request", "url", r.URL)
@@ -67,17 +65,16 @@ func TestBackgoundChecker(t *testing.T) {

 	var server *httptest.Server
 	server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		if r.URL.Path == "/update.json" {
-			w.Write([]byte(
-				fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
-					server.URL+"/9.9.9/"+Installer)))
+		switch r.URL.Path {
+		case "/update.json":
+			fmt.Fprintf(w, `{"version": "9.9.9", "url": "%s"}`, server.URL+"/9.9.9/"+Installer)
 			// TODO - wire up the redirects to mimic real behavior
-		} else if r.URL.Path == "/9.9.9/"+Installer {
+		case "/9.9.9/" + Installer:
 			buf := &bytes.Buffer{}
 			zw := zip.NewWriter(buf)
 			zw.Close()
 			io.Copy(w, buf)
-		} else {
+		default:
 			slog.Debug("unexpected request", "url", r.URL)
 		}
 	}))
@@ -149,7 +149,7 @@ func BenchmarkChat(fOpt flagOptions) error {

 	for _, model := range models {
 		for range *fOpt.epochs {
-			options := make(map[string]interface{})
+			options := make(map[string]any)
 			if *fOpt.maxTokens > 0 {
 				options["num_predict"] = *fOpt.maxTokens
 			}
@@ -442,7 +442,7 @@ func TestReadImage_FileNotFound(t *testing.T) {
 func TestOptionsMapCreation(t *testing.T) {
 	fOpt := createTestFlagOptions()

-	options := make(map[string]interface{})
+	options := make(map[string]any)
 	if *fOpt.maxTokens > 0 {
 		options["num_predict"] = *fOpt.maxTokens
 	}
cmd/cmd.go (35)
@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"
 	"log"
+	"maps"
 	"math"
 	"net"
 	"net/http"
@@ -203,7 +204,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {

 	if err := client.Create(cmd.Context(), req, fn); err != nil {
 		if strings.Contains(err.Error(), "path or Modelfile are required") {
-			return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
+			return errors.New("the ollama server must be updated to use `ollama create` with this client")
 		}
 		return err
 	}
@@ -990,7 +991,7 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 			var v string
 			switch vData := resp.ModelInfo[k].(type) {
 			case bool:
-				v = fmt.Sprintf("%t", vData)
+				v = strconv.FormatBool(vData)
 			case string:
 				v = vData
 			case float64:
@@ -1204,9 +1205,7 @@ func (r runOptions) Copy() runOptions {
 	var opts map[string]any
 	if r.Options != nil {
 		opts = make(map[string]any, len(r.Options))
-		for k, v := range r.Options {
-			opts[k] = v
-		}
+		maps.Copy(opts, r.Options)
 	}

 	var think *api.ThinkValue
@@ -1330,12 +1329,12 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		cancel()
 	}()

-	var state *displayResponseState = &displayResponseState{}
+	state := &displayResponseState{}
 	var thinkingContent strings.Builder
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
-	var thinkTagOpened bool = false
-	var thinkTagClosed bool = false
+	thinkTagOpened := false
+	thinkTagClosed := false

 	role := "assistant"

@@ -1463,10 +1462,10 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		cancel()
 	}()

-	var state *displayResponseState = &displayResponseState{}
+	state := &displayResponseState{}
 	var thinkingContent strings.Builder
-	var thinkTagOpened bool = false
-	var thinkTagClosed bool = false
+	thinkTagOpened := false
+	thinkTagClosed := false

 	plainText := !term.IsTerminal(int(os.Stdout.Fd()))

@@ -1634,7 +1633,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
 		return err
 	}
 	if err := client.Heartbeat(cmd.Context()); err != nil {
-		if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
+		if !strings.Contains(err.Error(), " refused") && !strings.Contains(err.Error(), "could not connect") {
 			return err
 		}
 		if err := startApp(cmd.Context(), client); err != nil {
@@ -1952,13 +1951,13 @@ func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicit
 }

 func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
-	out := ""
+	var sb strings.Builder
 	formatExplanation := ""
 	formatValues := ""
 	if !plainText {
 		formatExplanation = readline.ColorGrey + readline.ColorBold
 		formatValues = readline.ColorDefault
-		out += formatExplanation
+		sb.WriteString(formatExplanation)
 	}
 	for i, toolCall := range toolCalls {
 		argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
@@ -1966,13 +1965,13 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
 			return ""
 		}
 		if i > 0 {
-			out += "\n"
+			sb.WriteString("\n")
 		}
 		// all tool calls are unexpected since we don't currently support registering any in the CLI
-		out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
+		fmt.Fprintf(&sb, " Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
 	}
 	if !plainText {
-		out += readline.ColorDefault
+		sb.WriteString(readline.ColorDefault)
 	}
-	return out
+	return sb.String()
 }
@@ -3,6 +3,7 @@ package cmd
 import (
 	"bytes"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -307,7 +308,7 @@ func TestDeleteHandler(t *testing.T) {
 		} else {
 			w.WriteHeader(http.StatusNotFound)
 			errPayload := `{"error":"model '%s' not found"}`
-			w.Write([]byte(fmt.Sprintf(errPayload, req.Name)))
+			fmt.Fprintf(w, errPayload, req.Name)
 		}
 		return
 	}
@@ -761,8 +762,8 @@ func TestGetModelfileName(t *testing.T) {
 				t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
 			}

-			if tt.expectedErr != os.ErrNotExist {
-				if actualErr != tt.expectedErr {
+			if !errors.Is(tt.expectedErr, os.ErrNotExist) {
+				if !errors.Is(actualErr, tt.expectedErr) {
 					t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
 				}
 			} else {
@@ -924,10 +925,8 @@ func TestPushHandler(t *testing.T) {
 						t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
 					}
 				}
-			} else {
-				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
-					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
-				}
+			} else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+				t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
 			}
 		})
 	}
@@ -1014,10 +1013,8 @@ func TestListHandler(t *testing.T) {
 				if got := string(output); got != tt.expectedOutput {
 					t.Errorf("expected output:\n%s\ngot:\n%s", tt.expectedOutput, got)
 				}
-			} else {
-				if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
-					t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
-				}
+			} else if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
+				t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
 			}
 		})
 	}
@@ -1322,8 +1319,8 @@ func TestRunOptions_Copy(t *testing.T) {
 	// Test 2: Verify all fields are copied correctly
 	tests := []struct {
 		name string
-		got  interface{}
-		want interface{}
+		got  any
+		want any
 	}{
 		{"Model", copied.Model, original.Model},
 		{"ParentModel", copied.ParentModel, original.ParentModel},
@ -130,7 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
var multiline MultilineState
|
var multiline MultilineState
|
||||||
var thinkExplicitlySet bool = opts.Think != nil
|
thinkExplicitlySet := opts.Think != nil
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, err := scanner.Readline()
|
line, err := scanner.Readline()
|
||||||
|
|
@ -410,7 +410,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
if resp.Parameters == "" {
|
if resp.Parameters == "" {
|
||||||
fmt.Println(" No additional parameters were specified for this model.")
|
fmt.Println(" No additional parameters were specified for this model.")
|
||||||
} else {
|
} else {
|
||||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
for l := range strings.SplitSeq(resp.Parameters, "\n") {
|
||||||
fmt.Printf(" %s\n", l)
|
fmt.Printf(" %s\n", l)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
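The hunk above swaps strings.Split for strings.SplitSeq, which yields the pieces lazily instead of allocating a slice up front. A small sketch of the two forms, assuming a Go toolchain recent enough to ship the Seq variants of the strings functions:

package main

import (
	"fmt"
	"strings"
)

func main() {
	params := "temperature 0.7\ntop_p 0.9"

	// strings.Split builds a []string before the loop runs.
	for _, l := range strings.Split(params, "\n") {
		fmt.Println(l)
	}

	// strings.SplitSeq yields each substring one at a time, no intermediate slice.
	for l := range strings.SplitSeq(params, "\n") {
		fmt.Println(l)
	}
}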
@@ -576,9 +576,8 @@ func extractFileNames(input string) []string {

 func extractFileData(input string) (string, []api.ImageData, error) {
 filePaths := extractFileNames(input)
-var imgs []api.ImageData
-
-for _, fp := range filePaths {
+imgs := make([]api.ImageData, len(filePaths))
+for i, fp := range filePaths {
 nfp := normalizeFilePath(fp)
 data, err := getImageData(nfp)
 if errors.Is(err, os.ErrNotExist) {

@@ -591,7 +590,7 @@ func extractFileData(input string) (string, []api.ImageData, error) {
 input = strings.ReplaceAll(input, "'"+nfp+"'", "")
 input = strings.ReplaceAll(input, "'"+fp+"'", "")
 input = strings.ReplaceAll(input, fp, "")
-imgs = append(imgs, data)
+imgs[i] = data
 }
 return strings.TrimSpace(input), imgs, nil
 }
@@ -38,10 +38,10 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
 "general.file_type": uint32(1),
 "general.quantization_version": uint32(2),
 "tokenizer.ggml.pre": t.Pre,
-"tokenizer.ggml.model": t.Vocabulary.Model,
-"tokenizer.ggml.tokens": t.Vocabulary.Tokens,
-"tokenizer.ggml.scores": t.Vocabulary.Scores,
-"tokenizer.ggml.token_type": t.Vocabulary.Types,
+"tokenizer.ggml.model": t.Model,
+"tokenizer.ggml.tokens": t.Tokens,
+"tokenizer.ggml.scores": t.Scores,
+"tokenizer.ggml.token_type": t.Types,
 }

 if len(t.Merges) > 0 {

@@ -231,20 +231,20 @@ func ConvertModel(fsys fs.FS, f *os.File) error {

 switch {
 case vocabSize == 0:
-slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
-case vocabSize > len(t.Vocabulary.Tokens):
-slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
-for i := range vocabSize - len(t.Vocabulary.Tokens) {
-t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
-t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
-t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
+slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Tokens))
+case vocabSize > len(t.Tokens):
+slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Tokens))
+for i := range vocabSize - len(t.Tokens) {
+t.Tokens = append(t.Tokens, fmt.Sprintf("[PAD%d]", i))
+t.Scores = append(t.Scores, -1)
+t.Types = append(t.Types, tokenTypeUserDefined)
 }
-case vocabSize < len(t.Vocabulary.Tokens):
-slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
-p.VocabSize = uint32(len(t.Vocabulary.Tokens))
-p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
+case vocabSize < len(t.Tokens):
+slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Tokens))
+p.VocabSize = uint32(len(t.Tokens))
+p.TextModel.VocabSize = uint32(len(t.Tokens))
 default:
-slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
+slog.Debug("vocabulary", "size", len(t.Tokens))
 }

 ts, err := parseTensors(fsys, strings.NewReplacer(conv.Replacements()...))
@@ -137,7 +137,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
 }

 func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
+out := make([]*ggml.Tensor, 0, len(ts))
 for _, t := range ts {
 if slices.Contains([]string{
 "embeddings.position_ids",
@@ -44,14 +44,14 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
 }

 func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
-for _, t := range ts {
-out = append(out, &ggml.Tensor{
+out := make([]*ggml.Tensor, len(ts))
+for i, t := range ts {
+out[i] = &ggml.Tensor{
 Name: t.Name(),
 Kind: t.Kind(),
 Shape: t.Shape(),
 WriterTo: t,
-})
+}
 }

 return out
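This hunk, and the converter hunks that follow, switch from appending to a nil slice to preallocating the slice and assigning by index, which is the pattern the prealloc linter nudges toward whenever the final length is known up front. A minimal sketch of the pattern, using a hypothetical item type:

package main

import "fmt"

type item struct{ name string }

func names(items []item) []string {
	// The length is known in advance, so allocate once and index
	// instead of growing a nil slice with repeated append calls.
	out := make([]string, len(items))
	for i, it := range items {
		out[i] = it.name
	}
	return out
}

func main() {
	fmt.Println(names([]item{{"a"}, {"b"}}))
}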
@@ -43,18 +43,18 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
 }

 func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
-for _, t := range ts {
+out := make([]*ggml.Tensor, len(ts))
+for i, t := range ts {
 if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
 t.SetRepacker(p.addOne)
 }

-out = append(out, &ggml.Tensor{
+out[i] = &ggml.Tensor{
 Name: t.Name(),
 Kind: t.Kind(),
 Shape: t.Shape(),
 WriterTo: t,
-})
+}
 }

 return out
@@ -22,8 +22,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
 }

 func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
-for _, t := range ts {
+out := make([]*ggml.Tensor, len(ts))
+for i, t := range ts {
 shape := t.Shape()
 if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
 (strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {

@@ -31,12 +31,12 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
 t.SetRepacker(p.repack)
 }

-out = append(out, &ggml.Tensor{
+out[i] = &ggml.Tensor{
 Name: t.Name(),
 Kind: t.Kind(),
 Shape: t.Shape(),
 WriterTo: t,
-})
+}
 }

 return out
@@ -111,7 +111,7 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
 for name, mxfp4 := range mxfp4s {
 dims := mxfp4.blocks.Shape()
 if !strings.HasSuffix(name, ".weight") {
-name = name + ".weight"
+name += ".weight"
 }
 if strings.Contains(name, "ffn_down_exps") {
 out = append(out, &ggml.Tensor{
@@ -127,7 +127,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 }

 func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
+out := make([]*ggml.Tensor, 0, len(ts)+1)

 if p.RopeScaling.factors != nil {
 out = append(out, &ggml.Tensor{

@@ -176,9 +176,9 @@ func (p *llamaModel) Replacements() []string {
 }

 func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
-var dims []int
-for _, dim := range shape {
-dims = append(dims, int(dim))
+dims := make([]int, len(shape))
+for i, dim := range shape {
+dims[i] = int(dim)
 }

 var heads uint32
@@ -30,8 +30,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
 }

 func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
-for _, t := range ts {
+out := make([]*ggml.Tensor, len(ts))
+for i, t := range ts {
 shape := t.Shape()
 if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
 (strings.HasSuffix(t.Name(), "weight.lora_b") && shape[0] < shape[1]) {

@@ -41,12 +41,12 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
 t.SetRepacker(p.repack)
 }

-out = append(out, &ggml.Tensor{
+out[i] = &ggml.Tensor{
 Name: t.Name(),
 Kind: t.Kind(),
 Shape: shape,
 WriterTo: t,
-})
+}
 }

 return out
@@ -90,9 +90,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
 }

 func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
-
-for _, t := range ts {
+out := make([]*ggml.Tensor, len(ts))
+for i, t := range ts {
 if !strings.HasPrefix(t.Name(), "v.") {
 if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
 strings.HasSuffix(t.Name(), ".attn_k.weight") {

@@ -100,12 +99,12 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
 }
 }

-out = append(out, &ggml.Tensor{
+out[i] = &ggml.Tensor{
 Name: t.Name(),
 Kind: t.Kind(),
 Shape: t.Shape(),
 WriterTo: t,
-})
+}
 }

 return out

@@ -145,9 +144,9 @@ func (p *mistral3Model) Replacements() []string {
 }

 func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
-var dims []int
-for _, dim := range shape {
-dims = append(dims, int(dim))
+dims := make([]int, len(shape))
+for i, dim := range shape {
+dims[i] = int(dim)
 }

 var heads uint32
@@ -49,20 +49,20 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
 }

 func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
-var out []*ggml.Tensor
-for _, t := range ts {
-out = append(out, &ggml.Tensor{
+out := make([]*ggml.Tensor, len(ts))
+for i, t := range ts {
+out[i] = &ggml.Tensor{
 Name: t.Name(),
 Kind: t.Kind(),
 Shape: t.Shape(),
 WriterTo: t,
-})
+}
 }

 return out
 }

-func (p *qwen2Model) Replacements() []string {
+func (q *qwen2Model) Replacements() []string {
 return []string{
 "lm_head", "output",
 "model.embed_tokens", "token_embd",
@@ -90,9 +90,9 @@ func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
 return out
 }

-func (p *qwen25VLModel) Replacements() []string {
+func (q *qwen25VLModel) Replacements() []string {
 return append(
-p.qwen2Model.Replacements(),
+q.qwen2Model.Replacements(),
 "visual", "v",
 "blocks", "blk",
 "attn.proj", "attn_out",
@@ -54,6 +54,6 @@ func (t torch) Clone() Tensor {
 }
 }

-func (pt torch) WriteTo(w io.Writer) (int64, error) {
+func (t torch) WriteTo(w io.Writer) (int64, error) {
 return 0, nil
 }
@@ -82,7 +82,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 content string
 }

-var ts []t
+ts := make([]t, 0, len(atm))
 for content, id := range atm {
 ts = append(ts, t{id, content})
 }
@@ -300,9 +300,9 @@ func (s Tensors) Items(prefix ...string) []*Tensor {
 return items
 }

-func (ts Tensors) GroupLayers() map[string]Layer {
+func (s Tensors) GroupLayers() map[string]Layer {
 layers := make(map[string]Layer)
-for _, t := range ts.items {
+for _, t := range s.items {
 parts := strings.Split(t.Name, ".")
 if index := slices.IndexFunc(parts, func(s string) bool { return s == "blk" || s == "mm" }); index != -1 {
 if len(parts) > index+2 {
@@ -5,6 +5,7 @@ import (
 "cmp"
 "encoding/binary"
 "encoding/json"
+"errors"
 "fmt"
 "io"
 "log/slog"

@@ -225,7 +226,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 Name: name,
 Kind: kind,
 Offset: offset,
-Shape: shape[:],
+Shape: shape,
 }

 llm.tensors = append(llm.tensors, &tensor)
@@ -511,7 +512,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
 func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
 arch := kv.String("general.architecture")
 if arch == "" {
-return fmt.Errorf("architecture not set")
+return errors.New("architecture not set")
 }

 if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
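This hunk, and many later ones, swap fmt.Errorf for errors.New whenever the message has no format verbs, the sort of thing perfsprint and staticcheck-style checks flag. A minimal sketch of when each form applies:

package main

import (
	"errors"
	"fmt"
)

func check(arch string) error {
	if arch == "" {
		// Constant message: errors.New avoids the needless Sprintf machinery.
		return errors.New("architecture not set")
	}
	// Dynamic message: fmt.Errorf is still the right tool.
	return fmt.Errorf("unsupported architecture %q", arch)
}

func main() {
	fmt.Println(check(""), check("bert"))
}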
@@ -136,8 +136,8 @@ func (t FileType) Value() uint32 {
 return uint32(t)
 }

-func (ftype FileType) ToTensorType() TensorType {
-switch ftype {
+func (t FileType) ToTensorType() TensorType {
+switch t {
 case FileTypeF32:
 return TensorTypeF32
 case FileTypeF16:

@@ -177,7 +177,7 @@ func (ftype FileType) ToTensorType() TensorType {
 case fileTypeMXFP4:
 return TensorTypeMXFP4
 default:
-slog.Warn("unsupported file type", "type", ftype)
+slog.Warn("unsupported file type", "type", t)
 return 0 // F32
 }
 }
@@ -11,7 +11,7 @@ type KeyValue struct {
 }

 func (kv KeyValue) Valid() bool {
-return kv.Key != "" && kv.Value.value != nil
+return kv.Key != "" && kv.value != nil
 }

 type Value struct {
@@ -200,9 +200,7 @@ func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
 before := raw[:channelIndex]
 after := raw[channelIndex+len("<|channel|>"):]
 // the channel name is `after` all the way up to the first (if any) whitespace character
-idx := strings.IndexFunc(after, func(r rune) bool {
-return unicode.IsSpace(r)
-})
+idx := strings.IndexFunc(after, unicode.IsSpace)
 if idx == -1 {
 idx = len(after)
 }

@@ -319,11 +317,12 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
 }
 case HarmonyEventContentEmitted:
 logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
-if h.state == harmonyMessageState_Normal {
+switch h.state {
+case harmonyMessageState_Normal:
 contentSb.WriteString(event.Content)
-} else if h.state == harmonyMessageState_Thinking {
+case harmonyMessageState_Thinking:
 thinkingSb.WriteString(event.Content)
-} else if h.state == harmonyMessageState_ToolCalling {
+case harmonyMessageState_ToolCalling:
 toolContentSb.WriteString(event.Content)
 }
 case HarmonyEventMessageEnd:
@@ -263,9 +263,9 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
 cparams.use_mmap = C.bool(params.UseMmap)
 cparams.vocab_only = C.bool(params.VocabOnly)

-var devices []C.ggml_backend_dev_t
-for _, llamaID := range params.Devices {
-devices = append(devices, C.ggml_backend_dev_get(C.size_t(llamaID)))
+devices := make([]C.ggml_backend_dev_t, len(params.Devices))
+for i, llamaID := range params.Devices {
+devices[i] = C.ggml_backend_dev_get(C.size_t(llamaID))
 }
 if len(devices) > 0 {
 devices = append(devices, C.ggml_backend_dev_t(C.NULL))
@@ -250,7 +250,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 if s.status != nil && s.status.LastErrMsg != "" {
 msg = s.status.LastErrMsg
 }
-err := fmt.Errorf("error starting runner: %v %s", err, msg)
+err := fmt.Errorf("error starting runner: %w %s", err, msg)
 if llamaModel != nil {
 llama.FreeModel(llamaModel)
 }
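Several hunks in this file change fmt.Errorf's verb from %v to %w. Wrapping with %w keeps the original error in the chain so callers can still use errors.Is and errors.As, whereas %v only embeds its text. A small sketch of the difference, using a hypothetical sentinel:

package main

import (
	"errors"
	"fmt"
)

var errStart = errors.New("runner exited") // hypothetical sentinel

func main() {
	wrapped := fmt.Errorf("error starting runner: %w", errStart)
	flattened := fmt.Errorf("error starting runner: %v", errStart)

	fmt.Println(errors.Is(wrapped, errStart))   // true: %w preserves the chain
	fmt.Println(errors.Is(flattened, errStart)) // false: %v keeps only the message
}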
@@ -846,14 +846,7 @@ nextOperation:
 func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
 devices := []ml.DeviceID{}
 for _, layer := range gpuLayers {
-new := true
-for _, ID := range devices {
-if layer.DeviceID == ID {
-new = false
-break
-}
-}
-if new {
+if !slices.Contains(devices, layer.DeviceID) {
 devices = append(devices, layer.DeviceID)
 }
 }
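The hunk above (and a similar one in ml/device.go further down) replaces a hand-rolled membership loop with slices.Contains from the standard library (Go 1.21+). A minimal sketch of the dedup pattern, with hypothetical string IDs:

package main

import (
	"fmt"
	"slices"
)

func dedup(ids []string) []string {
	out := []string{}
	for _, id := range ids {
		// slices.Contains replaces the manual "already seen" loop.
		if !slices.Contains(out, id) {
			out = append(out, id)
		}
	}
	return out
}

func main() {
	fmt.Println(dedup([]string{"gpu0", "gpu1", "gpu0"}))
}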
@@ -989,13 +982,11 @@ nextLayer:
 slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
 return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available))
 }
-} else {
-if vramSize > systemInfo.TotalMemory {
+} else if vramSize > systemInfo.TotalMemory {
 // disable partial offloading when model is greater than total system memory as this
 // can lead to locking up the system
 s.options.NumGPU = 0
 gpuLayers = ml.GPULayersList{}
 }
-}

 if gpuLayers.Sum() == 0 {
@@ -1218,7 +1209,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {

 req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
 if err != nil {
-return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
+return ServerStatusError, fmt.Errorf("error creating GET request: %w", err)
 }
 req.Header.Set("Content-Type", "application/json")
@@ -1481,7 +1472,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 // User provided a JSON schema
 g := llama.SchemaToGrammar(req.Format)
 if g == nil {
-return fmt.Errorf("invalid JSON schema in format")
+return errors.New("invalid JSON schema in format")
 }
 req.Grammar = string(g)
 }

@@ -1521,13 +1512,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 enc.SetEscapeHTML(false)

 if err := enc.Encode(req); err != nil {
-return fmt.Errorf("failed to marshal data: %v", err)
+return fmt.Errorf("failed to marshal data: %w", err)
 }

 endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
 serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
 if err != nil {
-return fmt.Errorf("error creating POST request: %v", err)
+return fmt.Errorf("error creating POST request: %w", err)
 }
 serverReq.Header.Set("Content-Type", "application/json")
@@ -1576,7 +1567,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu

 var c CompletionResponse
 if err := json.Unmarshal(evt, &c); err != nil {
-return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
+return fmt.Errorf("error unmarshalling llm prediction response: %w", err)
 }
 switch {
 case strings.TrimSpace(c.Content) == lastToken:

@@ -1618,7 +1609,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 return fmt.Errorf("an error was encountered while running the model: %s", msg)
 }

-return fmt.Errorf("error reading llm response: %v", err)
+return fmt.Errorf("error reading llm response: %w", err)
 }

 return nil
@@ -1693,7 +1684,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
 defer s.llamaModelLock.Unlock()

 if s.llamaModel == nil {
-return nil, fmt.Errorf("no tokenizer configured")
+return nil, errors.New("no tokenizer configured")
 }

 return s.llamaModel.Tokenize(content, false, true)
@@ -1718,15 +1709,15 @@ func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, err
 defer s.llamaModelLock.Unlock()

 if s.llamaModel == nil {
-return "", fmt.Errorf("no tokenizer configured")
+return "", errors.New("no tokenizer configured")
 }

-var resp string
+var sb strings.Builder
 for _, token := range tokens {
-resp += s.llamaModel.TokenToPiece(token)
+sb.WriteString(s.llamaModel.TokenToPiece(token))
 }

-return resp, nil
+return sb.String(), nil
 }

 func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
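The Detokenize hunk above stops concatenating onto a string inside the loop and accumulates into a strings.Builder, which avoids reallocating the whole result on every iteration. A small sketch of the pattern:

package main

import (
	"fmt"
	"strings"
)

func join(pieces []string) string {
	var sb strings.Builder
	for _, p := range pieces {
		sb.WriteString(p) // appends into a growable buffer, no per-iteration copy of the result
	}
	return sb.String()
}

func main() {
	fmt.Println(join([]string{"Hel", "lo", ", ", "world"}))
}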
@@ -209,7 +209,7 @@ func TestLLMServerFitGPU(t *testing.T) {
 }

 gpuLayers, err := s.createLayout(systemInfo, tt.gpus, s.mem, tt.requireFull, 0)
-if err != tt.expectedErr {
+if !errors.Is(err, tt.expectedErr) {
 t.Fatalf("fitGPU returned error: %v", err)
 }
 if gpuLayers.Hash() != tt.expected.Hash() {
@@ -84,7 +84,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 }

 w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
 if err != nil {
 return 0, err
 }
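This hunk and the ones that follow format directly into the ResponseWriter with fmt.Fprintf instead of building an intermediate string and byte slice first. Roughly, for any io.Writer:

package main

import (
	"fmt"
	"os"
)

func main() {
	d := []byte(`{"done":true}`)

	// Before: format to a string, convert to []byte, then write.
	os.Stdout.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))

	// After: format straight into the writer, skipping the intermediate copies.
	fmt.Fprintf(os.Stdout, "data: %s\n\n", d)
}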
@@ -98,7 +98,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 if err != nil {
 return 0, err
 }
-_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
 if err != nil {
 return 0, err
 }
@@ -123,7 +123,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
 }

 func (w *ChatWriter) Write(data []byte) (int, error) {
-code := w.ResponseWriter.Status()
+code := w.Status()
 if code != http.StatusOK {
 return w.writeError(data)
 }

@@ -150,7 +150,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 }

 w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
-_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
 if err != nil {
 return 0, err
 }

@@ -164,7 +164,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 if err != nil {
 return 0, err
 }
-_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
+_, err = fmt.Fprintf(w.ResponseWriter, "data: %s\n\n", d)
 if err != nil {
 return 0, err
 }

@@ -189,7 +189,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
 }

 func (w *CompleteWriter) Write(data []byte) (int, error) {
-code := w.ResponseWriter.Status()
+code := w.Status()
 if code != http.StatusOK {
 return w.writeError(data)
 }
@@ -214,7 +214,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
 }

 func (w *ListWriter) Write(data []byte) (int, error) {
-code := w.ResponseWriter.Status()
+code := w.Status()
 if code != http.StatusOK {
 return w.writeError(data)
 }

@@ -240,7 +240,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
 }

 func (w *RetrieveWriter) Write(data []byte) (int, error) {
-code := w.ResponseWriter.Status()
+code := w.Status()
 if code != http.StatusOK {
 return w.writeError(data)
 }

@@ -265,7 +265,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
 }

 func (w *EmbedWriter) Write(data []byte) (int, error) {
-code := w.ResponseWriter.Status()
+code := w.Status()
 if code != http.StatusOK {
 return w.writeError(data)
 }
@@ -68,7 +68,7 @@ func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) {

 switch tc.expectType {
 case "array":
-if _, ok := result.Data[0].Embedding.([]interface{}); !ok {
+if _, ok := result.Data[0].Embedding.([]any); !ok {
 t.Errorf("expected array, got %T", result.Data[0].Embedding)
 }
 case "string":

@@ -210,10 +208,8 @@ func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) {
 if !strings.Contains(errResp.Error.Message, "encoding_format") {
 t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message)
 }
-} else {
-if resp.Code != http.StatusOK {
+} else if resp.Code != http.StatusOK {
 t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
-}
 }
 })
 }
@@ -845,19 +845,17 @@ func TestListMiddleware(t *testing.T) {
 resp := httptest.NewRecorder()
 router.ServeHTTP(resp, req)

-var expected, actual map[string]any
-err := json.Unmarshal([]byte(tc.resp), &expected)
-if err != nil {
+var want, got map[string]any
+if err := json.Unmarshal([]byte(tc.resp), &want); err != nil {
 t.Fatalf("failed to unmarshal expected response: %v", err)
 }

-err = json.Unmarshal(resp.Body.Bytes(), &actual)
-if err != nil {
+if err := json.Unmarshal(resp.Body.Bytes(), &got); err != nil {
 t.Fatalf("failed to unmarshal actual response: %v", err)
 }

-if !reflect.DeepEqual(expected, actual) {
-t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
+if diff := cmp.Diff(want, got); diff != "" {
+t.Errorf("response does not match (-want +got):\n%s", diff)
 }
 }
 }
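The test above moves from reflect.DeepEqual to cmp.Diff, which reports where two values differ instead of a bare boolean. A minimal sketch, assuming the github.com/google/go-cmp/cmp module is available:

package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

func main() {
	want := map[string]any{"object": "list", "count": 2}
	got := map[string]any{"object": "list", "count": 3}

	// An empty string means the values are equal; otherwise the diff pinpoints the mismatch.
	if diff := cmp.Diff(want, got); diff != "" {
		fmt.Printf("response does not match (-want +got):\n%s", diff)
	}
}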
@@ -4,6 +4,7 @@ import (
 "bytes"
 "context"
 "encoding/binary"
+"errors"
 "fmt"
 "math"
 "slices"

@@ -92,7 +93,7 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
 return backend(modelPath, params)
 }

-return nil, fmt.Errorf("unsupported backend")
+return nil, errors.New("unsupported backend")
 }

 type Context interface {
@@ -178,14 +178,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 requiredMemory.CPU.Cache = make([]uint64, blocks+1)

 // create list of buffer types for each gpu
-var gpuDeviceBufferTypes []deviceBufferType
+gpuDeviceBufferTypes := make([]deviceBufferType, len(gpus))
 requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
 for i, d := range gpus {
 bt := C.ggml_backend_dev_buffer_type(d)
-gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
+gpuDeviceBufferTypes[i] = deviceBufferType{
 d: d,
 bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
-})
+}

 btDeviceMemory[bt] = &requiredMemory.GPUs[i]
 requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))

@@ -354,8 +354,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 deviceBufferTypes := make(map[C.ggml_backend_dev_t]C.ggml_backend_buffer_type_t)

 // create backends and buffer types used for the compute graph scheduler
-var schedBackends []C.ggml_backend_t
-var schedBufts []C.ggml_backend_buffer_type_t
+schedBackends := make([]C.ggml_backend_t, 0, len(cpus)+len(accels)+len(gpus))
+schedBufts := make([]C.ggml_backend_buffer_type_t, 0, len(cpus)+len(accels)+len(gpus))
 for _, d := range append(gpus, append(accels, cpus...)...) {
 b := backends[d]
 bt := C.ggml_backend_get_default_buffer_type(b)

ml/device.go
@@ -4,6 +4,7 @@ import (
 "context"
 "encoding/binary"
 "encoding/json"
+"errors"
 "fmt"
 "hash/maphash"
 "io"

@@ -218,7 +219,7 @@ type BackendMemory struct {
 }

 func (m BackendMemory) LogValue() slog.Value {
-var attrs []slog.Attr
+attrs := make([]slog.Attr, 0, 2+len(m.GPUs))
 if m.InputWeights != 0 {
 attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
 }
@@ -414,14 +415,7 @@ func LibraryPaths(l []DeviceInfo) []string {
 gpuLibs := []string{LibOllamaPath}
 for _, gpu := range l {
 for _, dir := range gpu.LibraryPath {
-needed := true
-for _, existing := range gpuLibs {
-if dir == existing {
-needed = false
-break
-}
-}
-if needed {
+if !slices.Contains(gpuLibs, dir) {
 gpuLibs = append(gpuLibs, dir)
 }
 }
@@ -437,15 +431,15 @@ const (
 DuplicateDevice // The same physical device but different library/backend (overlapping device)
 )

-func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
-if a.PCIID != b.PCIID {
+func (d DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
+if d.PCIID != b.PCIID {
 return UniqueDevice
 }
 // If PCIID is empty, we have to use ID + library for uniqueness
-if a.PCIID == "" && a.DeviceID != b.DeviceID {
+if d.PCIID == "" && d.DeviceID != b.DeviceID {
 return UniqueDevice
 }
-if a.Library == b.Library {
+if d.Library == b.Library {
 return SameBackendDevice
 }
 return DuplicateDevice

@@ -453,8 +447,8 @@ func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {

 // For a SameBackendDevice, return true if b is better than a
 // e.g. newer GPU library version
-func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
-aLib := a.LibraryPath[len(a.LibraryPath)-1]
+func (d DeviceInfo) IsBetter(b DeviceInfo) bool {
+aLib := d.LibraryPath[len(d.LibraryPath)-1]
 bLib := b.LibraryPath[len(b.LibraryPath)-1]
 if aLib == bLib {
 return false
@@ -481,7 +475,7 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
 for _, gpu := range l {
 supportsFA := gpu.Library == "cpu" ||
 gpu.Name == "Metal" || gpu.Library == "Metal" ||
-(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
+(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && (gpu.ComputeMajor != 7 || gpu.ComputeMinor != 2)) ||
 gpu.Library == "ROCm" ||
 gpu.Library == "Vulkan"
@@ -549,12 +543,12 @@ func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) {
 }
 v, existing := env[envVar]
 if existing {
-v = v + ","
+v += ","
 }
 if d.FilterID != "" {
-v = v + d.FilterID
+v += d.FilterID
 } else {
-v = v + d.ID
+v += d.ID
 }
 env[envVar] = v
 }
@@ -594,7 +588,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
 for {
 select {
 case <-ctx.Done():
-return nil, fmt.Errorf("failed to finish discovery before timeout")
+return nil, errors.New("failed to finish discovery before timeout")
 case <-tick:
 r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
 if err != nil {

@@ -606,7 +600,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,
 if err != nil {
 // slog.Warn("failed to send request", "error", err)
 if runner.HasExited() {
-return nil, fmt.Errorf("runner crashed")
+return nil, errors.New("runner crashed")
 }
 continue
 }

@@ -614,7 +608,7 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo,

 if resp.StatusCode == http.StatusNotFound {
 // old runner, fall back to bootstrapping model
-return nil, fmt.Errorf("llamarunner free vram reporting not supported")
+return nil, errors.New("llamarunner free vram reporting not supported")
 }

 body, err := io.ReadAll(resp.Body)
@@ -143,9 +143,9 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
 case r == 0x00ad:
 r = 0x0143
 case r <= 0x0020:
-r = r + 0x0100
+r += 0x0100
 case r >= 0x007f && r <= 0x00a0:
-r = r + 0x00a2
+r += 0x00a2
 }

 sb.WriteRune(r)

@@ -264,9 +264,9 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
 case r == 0x0143:
 r = 0x00ad
 case r > 0x0100 && r <= 0x0120:
-r = r - 0x0100
+r -= 0x0100
 case r > 0x0120 && r <= 0x0142:
-r = r - 0x00a2
+r -= 0x00a2
 }

 // NOTE: not using WriteRune here because it writes the UTF-8
@@ -146,7 +146,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
 func modelForArch(c fs.Config) (Model, error) {
 arch := c.Architecture()
 if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
-arch = arch + "_embed"
+arch += "_embed"
 }

 f, ok := models[arch]
@@ -175,9 +175,10 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 tagsCopy = append(tagsCopy, parseTag(tag))
 }

-if tt == reflect.TypeOf((*Base)(nil)).Elem() {
+switch {
+case tt == reflect.TypeFor[Base]():
 vv.Set(reflect.ValueOf(base))
-} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
+case tt == reflect.TypeFor[ml.Tensor]():
 var fn func([]Tag, string, string) [][]string
 fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
 if len(tags) > 0 {

@@ -217,9 +218,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 break
 }
 }
-} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
+case tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface:
 setPointer(base, vv, tagsCopy)
-} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
+case tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array:
 for i := range vv.Len() {
 vvv := vv.Index(i)
 if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
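populateFields above also swaps reflect.TypeOf((*T)(nil)).Elem() for the generic reflect.TypeFor[T]() helper (Go 1.22+), which names the interface type directly. A small sketch of the equivalence:

package main

import (
	"fmt"
	"io"
	"reflect"
)

func main() {
	// Old spelling: build a *io.Reader, take its type, then Elem() to reach io.Reader.
	old := reflect.TypeOf((*io.Reader)(nil)).Elem()

	// New spelling: reflect.TypeFor names the type parameter directly.
	now := reflect.TypeFor[io.Reader]()

	fmt.Println(old == now) // true
}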
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 }

 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
+return fast.RoPE(ctx, key, shift, m.attnKeyLen, m.ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
 }

 type MLP struct {

@@ -178,10 +178,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

 hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
-hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
+hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))

 if len(m.Layers) == gemma27BLayerCount {
-m.Options.largeModelScaling = true
+m.largeModelScaling = true
 }

 for i, layer := range m.Layers {

@@ -202,9 +202,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 hiddenState = m.Output.Forward(ctx, hiddenState)

 // final logit softcap
-hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
+hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.finalLogitSoftcap))
 hiddenState = hiddenState.Tanh(ctx)
-return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
+return hiddenState.Scale(ctx, float64(m.finalLogitSoftcap)), nil
 }

 func init() {
@@ -96,15 +96,15 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 return nil, err
 }

-f32s, err := m.ImageProcessor.ProcessImage(image)
+f32s, err := m.ProcessImage(image)
 if err != nil {
 return nil, err
 }

 pixelValues := ctx.Input().FromFloats(f32s,
-m.ImageProcessor.imageSize,
-m.ImageProcessor.imageSize,
-m.ImageProcessor.numChannels,
+m.imageSize,
+m.imageSize,
+m.numChannels,
 )

 visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
@@ -111,12 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 }

 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-ropeBase := m.TextConfig.ropeLocalBase
+ropeBase := m.ropeLocalBase
 if (layer+1)%gemmaGlobalCacheCount == 0 {
-ropeBase = m.TextConfig.ropeGlobalBase
+ropeBase = m.ropeGlobalBase
 }

-return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
+return fast.RoPE(ctx, key, shift, m.attnKeyLen, ropeBase, 1/m.ropeScale, rope.WithTypeNeoX()), nil
 }

 type TextMLP struct {

@@ -166,7 +166,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

 hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
-hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
+hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))

 // set image embeddings
 var except []int
@@ -53,7 +53,7 @@ func New(c fs.Config) (model.Model, error) {
 MultiModalProjector: newMultiModalProjector(c),
 }

-m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
+m.Cache = kvcache.NewCausalCache(m.Shift)

 return m, nil
 }

@@ -109,12 +109,12 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 return nil, err
 }

-f32s, size, err := m.ImageProcessor.ProcessImage(image)
+f32s, size, err := m.ProcessImage(image)
 if err != nil {
 return nil, err
 }

-pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
+pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.numChannels)

 visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
 features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)

@@ -133,7 +133,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
 hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
 hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
+hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.eps)

 // Prepare position IDs for 2D rope
 positions := make([]int32, numPatches)
@@ -54,7 +54,7 @@ func New(c fs.Config) (model.Model, error) {

 encoderCache := kvcache.NewEncoderCache()
 encoderCache.SetConfig(ml.CacheConfig{})
-m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
+m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.Shift))

 return &m, nil
 }

@@ -69,7 +69,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 return nil, err
 }

-f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
+f32s, ratio, err := m.ProcessImage(image)
 if err != nil {
 return nil, err
 }
@@ -223,8 +223,8 @@ func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, cros
 }

 func newTextModel(c fs.Config) *TextModel {
-var decoderLayers []TextDecoderLayer
-for i := range c.Uint("block_count") {
+decoderLayers := make([]TextDecoderLayer, c.Uint("block_count"))
+for i := range decoderLayers {
 var textDecoderLayer TextDecoderLayer
 if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) {
 textDecoderLayer = &TextCrossAttentionDecoderLayer{}

@@ -232,7 +232,7 @@ func newTextModel(c fs.Config) *TextModel {
 textDecoderLayer = &TextSelfAttentionDecoderLayer{}
 }

-decoderLayers = append(decoderLayers, textDecoderLayer)
+decoderLayers[i] = textDecoderLayer
 }

 return &TextModel{
|
|
@ -2,6 +2,7 @@ package qwen2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
@ -130,7 +131,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
// This model currently only supports the gpt2 tokenizer
|
// This model currently only supports the gpt2 tokenizer
|
||||||
if c.String("tokenizer.ggml.model") == "llama" {
|
if c.String("tokenizer.ggml.model") == "llama" {
|
||||||
return nil, fmt.Errorf("unsupported tokenizer: llama")
|
return nil, errors.New("unsupported tokenizer: llama")
|
||||||
}
|
}
|
||||||
// detect library/qwen model(s) which are incompatible
|
// detect library/qwen model(s) which are incompatible
|
||||||
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
|
if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
|
||||||
|
|
|
||||||
|
|
@@ -48,7 +48,7 @@ func New(c fs.Config) (model.Model, error) {
 ImageProcessor: newImageProcessor(c),
 }
 
-m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
+m.Cache = kvcache.NewCausalCache(m.Shift)
 
 return m, nil
 }
@@ -59,14 +59,13 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
 return nil, nil, err
 }
 
-f32s, grid, err := m.ImageProcessor.ProcessImage(image)
+f32s, grid, err := m.ProcessImage(image)
 if err != nil {
 return nil, nil, err
 }
 
 // Calculate tensor dimensions
-patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
-m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
+patchDim := m.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
 numPatches := grid.Temporal * grid.Height * grid.Width
 
 pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
@@ -228,7 +228,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid)
 cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
 sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
 
-mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
+mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.numHeads)
 // Apply encoder layers
 for i, layer := range m.Layers {
 if slices.Contains(m.fullAttnBlocks, int32(i)) {
@@ -107,7 +107,7 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
 
 patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
 if err != nil {
-return nil, nil, fmt.Errorf("failed to create patches: %v", err)
+return nil, nil, fmt.Errorf("failed to create patches: %w", err)
 }
 
 // Return patches and grid dimensions
@@ -203,7 +203,7 @@ func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
+return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
 }
 
 var _ model.Model = (*Model)(nil)
@@ -111,7 +111,7 @@ func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tenso
 
 patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
 if err != nil {
-return nil, nil, fmt.Errorf("failed to create patches: %v", err)
+return nil, nil, fmt.Errorf("failed to create patches: %w", err)
 }
 
 patchDim := p.numChannels * p.temporalPatchSize *
@@ -98,7 +98,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
 if multiStepTool && message.Role == "user" {
 // Check if content starts with <tool_response> and ends with </tool_response>
 content := r.renderContent(message)
-if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
+if !strings.HasPrefix(content, "<tool_response>") || !strings.HasSuffix(content, "</tool_response>") {
 multiStepTool = false
 lastQueryIndex = i
 }
@@ -205,12 +205,12 @@ func (q queue) Less(i, j int) bool {
 
 func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
 
-func (q *queue) Push(x interface{}) {
+func (q *queue) Push(x any) {
 item := x.(*candidate)
 *q = append(*q, item)
 }
 
-func (q *queue) Pop() interface{} {
+func (q *queue) Pop() any {
 old := *q
 n := len(old)
 item := old[n-1]
@@ -231,7 +231,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
 if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
 byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
 if err != nil {
-return "", fmt.Errorf("failed to parse hex byte: %v", err)
+return "", fmt.Errorf("failed to parse hex byte: %w", err)
 }
 
 if err := sb.WriteByte(byte(byteVal)); err != nil {
@@ -232,9 +232,9 @@ func NewError(code int, message string) ErrorResponse {
 // ToUsage converts an api.ChatResponse to Usage
 func ToUsage(r api.ChatResponse) Usage {
 return Usage{
-PromptTokens: r.Metrics.PromptEvalCount,
-CompletionTokens: r.Metrics.EvalCount,
-TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
+PromptTokens: r.PromptEvalCount,
+CompletionTokens: r.EvalCount,
+TotalTokens: r.PromptEvalCount + r.EvalCount,
 }
 }
 
@@ -326,9 +326,9 @@ func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
 // ToUsageGenerate converts an api.GenerateResponse to Usage
 func ToUsageGenerate(r api.GenerateResponse) Usage {
 return Usage{
-PromptTokens: r.Metrics.PromptEvalCount,
-CompletionTokens: r.Metrics.EvalCount,
-TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
+PromptTokens: r.PromptEvalCount,
+CompletionTokens: r.EvalCount,
+TotalTokens: r.PromptEvalCount + r.EvalCount,
 }
 }
 
@@ -377,20 +377,19 @@ func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
 
 // ToListCompletion converts an api.ListResponse to ListCompletion
 func ToListCompletion(r api.ListResponse) ListCompletion {
-var data []Model
-for _, m := range r.Models {
-data = append(data, Model{
-Id: m.Name,
-Object: "model",
-Created: m.ModifiedAt.Unix(),
-OwnedBy: model.ParseName(m.Name).Namespace,
-})
-}
-
-return ListCompletion{
-Object: "list",
-Data: data,
-}
+c := ListCompletion{Object: "list"}
+if len(r.Models) > 0 {
+c.Data = make([]Model, len(r.Models))
+for i, m := range r.Models {
+c.Data[i] = Model{
+Id: m.Name,
+Object: "model",
+Created: m.ModifiedAt.Unix(),
+OwnedBy: model.ParseName(m.Name).Namespace,
+}
+}
+}
+
+return c
 }
 
 // ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
@@ -487,19 +486,14 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 }
 }
 
-types := []string{"jpeg", "jpg", "png", "webp"}
-valid := false
-// support blank mime type to match api/chat taking just unadorned base64
-if strings.HasPrefix(url, "data:;base64,") {
-url = strings.TrimPrefix(url, "data:;base64,")
-valid = true
-}
-for _, t := range types {
-prefix := "data:image/" + t + ";base64,"
-if strings.HasPrefix(url, prefix) {
-url = strings.TrimPrefix(url, prefix)
-valid = true
-break
-}
+url, valid := strings.CutPrefix(url, "data:;base64,")
+if !valid {
+for _, t := range []string{"jpeg", "jpg", "png", "webp"} {
+prefix := "data:image/" + t + ";base64,"
+url, valid = strings.CutPrefix(url, prefix)
+if valid {
+break
+}
 }
 }
 
@@ -7,6 +7,7 @@ import (
 "errors"
 "fmt"
 "io"
+"maps"
 "net/http"
 "os"
 "os/user"
@@ -78,9 +79,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
 if req.Files == nil {
 req.Files = digestMap
 } else {
-for k, v := range digestMap {
-req.Files[k] = v
-}
+maps.Copy(req.Files, digestMap)
 }
 case "adapter":
 path, err := expandPath(c.Args, relativeDir)
@@ -371,7 +370,7 @@ func (e *ParserError) Error() string {
 func ParseFile(r io.Reader) (*Modelfile, error) {
 var cmd Command
 var curr state
-var currLine int = 1
+currLine := 1
 var b bytes.Buffer
 var role string
 
@@ -326,17 +326,11 @@ MESSAGE system`,
 return
 }
 
-switch tt.err.(type) {
-case *ParserError:
-var pErr *ParserError
-if errors.As(err, &pErr) {
-// got the correct type of error
-return
-}
-}
-
 if errors.Is(err, tt.err) {
 return
+} else if pErr := (*ParserError)(nil); errors.As(err, &pErr) {
+// got the correct type of error
+return
 }
 
 t.Fatalf("unexpected error: expected: %v, actual: %v", tt.err, err)
@@ -1089,7 +1083,7 @@ func TestFilesForModel(t *testing.T) {
 if err == nil {
 t.Error("Expected error, but got none")
 }
-if tt.expectErrType != nil && err != tt.expectErrType {
+if tt.expectErrType != nil && !errors.Is(err, tt.expectErrType) {
 t.Errorf("Expected error type %v, got %v", tt.expectErrType, err)
 }
 return
@@ -3,6 +3,7 @@ package readline
 import (
 "fmt"
 "os"
+"strings"
 
 "github.com/emirpasic/gods/v2/lists/arraylist"
 "github.com/mattn/go-runewidth"
@@ -297,7 +298,7 @@ func (b *Buffer) drawRemaining() {
 remaining := (remainingText[len(currLine):])
 var totalLines int
 var displayLength int
-var lineLength int = currLineSpace
+lineLength := currLineSpace
 
 for _, c := range remaining {
 if displayLength == 0 || (displayLength+runewidth.RuneWidth(c))%b.LineWidth < displayLength%b.LineWidth {
@@ -515,13 +516,13 @@ func (b *Buffer) StringN(n int) string {
 }
 
 func (b *Buffer) StringNM(n, m int) string {
-var s string
+var sb strings.Builder
 if m == 0 {
 m = b.Buf.Size()
 }
 for cnt := n; cnt < m; cnt++ {
 c, _ := b.Buf.Get(cnt)
-s += string(c)
+sb.WriteRune(c)
 }
-return s
+return sb.String()
 }
@@ -5,6 +5,7 @@ import (
 "fmt"
 "io"
 "os"
+"strings"
 )
 
 type Prompt struct {
@@ -124,18 +125,19 @@ func (i *Instance) Readline() (string, error) {
 case KeyRight:
 buf.MoveRight()
 case CharBracketedPaste:
-var code string
+var code strings.Builder
 for range 3 {
 r, err = i.Terminal.Read()
 if err != nil {
 return "", io.EOF
 }
 
-code += string(r)
+code.WriteRune(r)
 }
-if code == CharBracketedPasteStart {
+switch code.String() {
+case CharBracketedPasteStart:
 i.Pasting = true
-} else if code == CharBracketedPasteEnd {
+case CharBracketedPasteEnd:
 i.Pasting = false
 }
 case KeyDel:
@@ -459,10 +459,7 @@ func TestLogprobsWithStopSequences(t *testing.T) {
 
 origLogprobsLen := len(logprobs)
 numTokensRemoved := origLen - newLen
-newLogprobsLen := origLogprobsLen - numTokensRemoved
-if newLogprobsLen < 0 {
-newLogprobsLen = 0
-}
+newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
 logprobs = logprobs[:newLogprobsLen]
 
 // Verify responses were truncated correctly
@@ -39,21 +39,15 @@ func TruncateStop(pieces []string, stop string) ([]string, bool) {
 
 joined = joined[:index]
 
-// Split truncated string back into pieces of original lengths
-lengths := make([]int, len(pieces))
-for i, piece := range pieces {
-lengths[i] = len(piece)
-}
-
-var result []string
+result := make([]string, 0, len(pieces))
 tokenTruncated := false
 start := 0
-for _, length := range lengths {
+for _, piece := range pieces {
 if start >= len(joined) {
 break
 }
 
-end := start + length
+end := start + len(piece)
 if end > len(joined) {
 end = len(joined)
 tokenTruncated = true
@@ -61,7 +61,7 @@ func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []by
 return nil, nil
 }
 
-if len(data) <= 0 {
+if len(data) == 0 {
 return nil, errors.New("received zero length image")
 }
 
@@ -1,6 +1,7 @@
 package llamarunner
 
 import (
+"errors"
 "reflect"
 "testing"
 
@@ -18,7 +19,7 @@ func TestImageCache(t *testing.T) {
 
 // Empty cache
 result, err := cache.findImage(0x5adb61d31933a946)
-if err != errImageNotFound {
+if !errors.Is(err, errImageNotFound) {
 t.Errorf("found result in empty cache: result %v, err %v", result, err)
 }
 
@@ -577,10 +577,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 if seq.logprobs {
 origLogprobsLen := len(seq.pendingLogprobs)
 numTokensRemoved := origLen - newLen
-newLogprobsLen := origLogprobsLen - numTokensRemoved
-if newLogprobsLen < 0 {
-newLogprobsLen = 0
-}
+newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
 seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
 }
 
@@ -998,7 +995,6 @@ func Execute(args []string) error {
 
 log.Println("Server listening on", addr)
 if err := httpServer.Serve(listener); err != nil {
-log.Fatal("server error:", err)
 return err
 }
 
@@ -2,7 +2,6 @@ package ollamarunner
 
 import (
 "errors"
-"fmt"
 "slices"
 "testing"
 "time"
@@ -511,7 +510,7 @@ type mockCache struct {
 // Implement only the methods needed for the test
 func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
 if m.shouldFail {
-return fmt.Errorf("mock cache removal error")
+return errors.New("mock cache removal error")
 }
 return nil
 }
@@ -801,10 +801,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 if seq.logprobs {
 origLogprobsLen := len(seq.pendingLogprobs)
 numTokensRemoved := origLen - newLen
-newLogprobsLen := origLogprobsLen - numTokensRemoved
-if newLogprobsLen < 0 {
-newLogprobsLen = 0
-}
+newLogprobsLen := max(origLogprobsLen-numTokensRemoved, 0)
 seq.pendingLogprobs = seq.pendingLogprobs[:newLogprobsLen]
 }
 
@@ -1242,7 +1239,7 @@ func (s *Server) loadModel() {
 s.progress = progress
 })
 if err != nil {
-panic(fmt.Errorf("failed to load model: %v", err))
+panic(fmt.Errorf("failed to load model: %w", err))
 }
 
 s.status = llm.ServerStatusReady
@@ -1432,7 +1429,6 @@ func Execute(args []string) error {
 
 log.Println("Server listening on", addr)
 if err := httpServer.Serve(listener); err != nil {
-log.Fatal("server error:", err)
 return err
 }
 
@@ -30,7 +30,7 @@ func temperature(ts []token, temp float32) {
 // Ensure temperature clipping near 0 to avoid numerical instability
 temp = max(temp, 1e-7)
 for i := range ts {
-ts[i].value = ts[i].value / temp
+ts[i].value /= temp
 }
 }
 
@@ -33,7 +33,7 @@ func (r registryChallenge) URL() (*url.URL, error) {
 
 values := redirectURL.Query()
 values.Add("service", r.Service)
-for _, s := range strings.Split(r.Scope, " ") {
+for s := range strings.SplitSeq(r.Scope, " ") {
 values.Add("scope", s)
 }
 
@@ -57,7 +57,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
 }
 
 sha256sum := sha256.Sum256(nil)
-data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
+data := fmt.Appendf(nil, "%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:]))))
 
 headers := make(http.Header)
 signature, err := auth.Sign(ctx, data)
@@ -75,7 +75,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
 
 body, err := io.ReadAll(response.Body)
 if err != nil {
-return "", fmt.Errorf("%d: %v", response.StatusCode, err)
+return "", fmt.Errorf("%d: %w", response.StatusCode, err)
 }
 
 if response.StatusCode >= http.StatusBadRequest {
@@ -386,7 +386,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
 }
 if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
 // Path is likely outside the root
-return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
+return nil, fmt.Errorf("%w: %w: %s", errFilePath, err, fp)
 }
 
 blobPath, err := GetBlobsPath(digest)
@@ -456,15 +456,15 @@ func kvFromLayers(baseLayers []*layerGGML) (ggml.KV, error) {
 return l.KV(), nil
 }
 }
-return ggml.KV{}, fmt.Errorf("no base model was found")
+return ggml.KV{}, errors.New("no base model was found")
 }
 
 func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
-var layers []Layer
-for _, layer := range baseLayers {
+layers := make([]Layer, len(baseLayers))
+for i, layer := range baseLayers {
 if layer.GGML != nil {
 quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
-if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
+if quantType != "" && layer.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
 want, err := ggml.ParseFileType(quantType)
 if err != nil {
 return err
@@ -480,13 +480,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
 }
 }
 }
-config.ModelFormat = cmp.Or(config.ModelFormat, layer.GGML.Name())
+config.ModelFormat = cmp.Or(config.ModelFormat, layer.Name())
 config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture())
 config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
 config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
 config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
 }
-layers = append(layers, layer.Layer)
+layers[i] = layer.Layer
 }
 
 if r.Template != "" {
@@ -678,10 +678,10 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
 func setTemplate(layers []Layer, t string) ([]Layer, error) {
 layers = removeLayer(layers, "application/vnd.ollama.image.template")
 if _, err := template.Parse(t); err != nil {
-return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
+return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
 }
 if _, err := template.Parse(t); err != nil {
-return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
+return nil, fmt.Errorf("%w: %w", errBadTemplate, err)
 }
 
 blob := strings.NewReader(t)
@@ -640,7 +640,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 
 manifest, err = pullModelManifest(ctx, mp, regOpts)
 if err != nil {
-return fmt.Errorf("pull model manifest: %s", err)
+return fmt.Errorf("pull model manifest: %w", err)
 }
 
 var layers []Layer
@@ -786,7 +786,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 defer resp.Body.Close()
 responseBody, err := io.ReadAll(resp.Body)
 if err != nil {
-return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
+return nil, fmt.Errorf("%d: %w", resp.StatusCode, err)
 }
 return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
 default:
@@ -438,7 +438,7 @@ func (w *checkWriter) Write(p []byte) (int, error) {
 // last write. check hash.
 sum := w.h.Sum(nil)
 if !bytes.Equal(sum, w.d.sum[:]) {
-return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
+return 0, w.seterr(errors.New("file content changed underfoot"))
 }
 if w.testHookBeforeFinalWrite != nil {
 w.testHookBeforeFinalWrite(w.f)
@@ -84,8 +84,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
 
 // TODO(bmizerany): Print platform-specific instructions or
 // link to docs on that topic.
-lines := strings.Split(volumeHint, "\n")
-for _, line := range lines {
+for line := range strings.SplitSeq(volumeHint, "\n") {
 t.Skip(line)
 }
 }
@@ -60,7 +60,7 @@ func (d Digest) String() string {
 }
 
 func (d Digest) Short() string {
-return fmt.Sprintf("%x", d.sum[:4])
+return hex.EncodeToString(d.sum[:4])
 }
 
 func (d Digest) Sum() [32]byte {
@@ -1184,11 +1184,11 @@ func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
 }
 start, err := strconv.ParseInt(startPart, 10, 64)
 if err != nil {
-return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
+return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %w", s, err)
 }
 end, err := strconv.ParseInt(endPart, 10, 64)
 if err != nil {
-return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
+return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %w", s, err)
 }
 if start > end {
 return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
@@ -142,7 +142,7 @@ var junkName Name
 
 func BenchmarkParseName(b *testing.B) {
 b.ReportAllocs()
-for range b.N {
+for b.Loop() {
 junkName = Parse("h/n/m:t")
 }
 }
@@ -187,15 +187,15 @@ func (w *relayWriter) Close() error {
 return nil
 }
 
-func (t *relayWriter) awaitTurn() (ok bool) {
-if t.ready {
+func (w *relayWriter) awaitTurn() (ok bool) {
+if w.ready {
 return true
 }
 select {
-case <-t.t.Ready():
-t.ready = true
+case <-w.t.Ready():
+w.ready = true
 return true
-case <-t.q.closed():
+case <-w.q.closed():
 return false
 }
 }
@@ -251,7 +251,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
 type progressUpdateJSON struct {
 Error string `json:"error,omitempty,omitzero"`
 Status string `json:"status,omitempty,omitzero"`
-Digest blob.Digest `json:"digest,omitempty,omitzero"`
+Digest blob.Digest `json:"digest,omitzero"`
 Total int64 `json:"total,omitempty,omitzero"`
 Completed int64 `json:"completed,omitempty,omitzero"`
 }
@@ -74,7 +74,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 return "", nil, errors.New("this model only supports one image while more than one image requested")
 }
 
-var prefix string
+var prefix strings.Builder
 prompt := msg.Content
 
 for _, i := range msg.Images {
@@ -85,14 +85,14 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
 if !strings.Contains(prompt, "[img]") {
-prefix += imgTag
+prefix.WriteString(imgTag)
 } else {
 prompt = strings.Replace(prompt, "[img]", imgTag, 1)
 }
 
 images = append(images, imgData)
 }
-msgs[currMsgIdx+cnt].Content = prefix + prompt
+msgs[currMsgIdx+cnt].Content = prefix.String() + prompt
 }
 
 // truncate any messages that do not fit into the context window
@@ -2,6 +2,7 @@ package server
 
 import (
 "bytes"
+"errors"
 "testing"
 
 "github.com/google/go-cmp/cmp"
@@ -238,7 +239,7 @@ func TestChatPrompt(t *testing.T) {
 prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
 if tt.error == nil && err != nil {
 t.Fatal(err)
-} else if tt.error != nil && err != tt.error {
+} else if tt.error != nil && !errors.Is(err, tt.error) {
 t.Fatalf("expected err '%q', got '%q'", tt.error, err)
 }
 
@@ -31,7 +31,7 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
 data, err := io.ReadAll(sr)
 if err != nil {
 slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
-return 0, fmt.Errorf("unable to read tensor %s from %s: %s", q.from.Name, q.Name(), err)
+return 0, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
 }
 var f32s []float32
 newType := fsggml.TensorType(q.to.Kind)
@@ -12,6 +12,7 @@ import (
 "io"
 "io/fs"
 "log/slog"
+"maps"
 "math"
 "math/rand"
 "net"
@@ -129,7 +130,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
 }
 
 if slices.Contains(model.Config.ModelFamilies, "mllama") && len(model.ProjectorPaths) > 0 {
-return nil, nil, nil, fmt.Errorf("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
+return nil, nil, nil, errors.New("'llama3.2-vision' is no longer compatible with your version of Ollama and has been replaced by a newer version. To re-download, run 'ollama pull llama3.2-vision'")
 }
 
 if err := model.CheckCapabilities(caps...); err != nil {
@@ -361,11 +362,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 if req.Think == nil {
 req.Think = &api.ThinkValue{Value: true}
 }
-} else {
-if req.Think != nil && req.Think.Bool() {
-c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
-return
-}
+} else if req.Think != nil && req.Think.Bool() {
+c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
+return
 }
 
 r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -649,10 +648,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 return
 }
 
-truncate := true
-if req.Truncate != nil && !*req.Truncate {
-truncate = false
-}
+truncate := req.Truncate == nil || *req.Truncate
 
 var input []string
 
@@ -825,9 +821,9 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 return
 }
 
-var e []float64
-for _, v := range embedding {
-e = append(e, float64(v))
+e := make([]float64, len(embedding))
+for i, v := range embedding {
+e[i] = float64(v)
 }
 
 resp := api.EmbeddingResponse{
@@ -1139,9 +1135,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
 if m.Options == nil {
 m.Options = make(map[string]any)
 }
-for k, v := range req.Options {
-m.Options[k] = v
-}
+maps.Copy(m.Options, req.Options)
 }
 
 var sb strings.Builder
@@ -1212,7 +1206,7 @@ func (s *Server) ListHandler(c *gin.Context) {
 return
 }
 
-models := []api.ListModelResponse{}
+models := make([]api.ListModelResponse, 0, len(ms))
 for n, m := range ms {
 var cf ConfigV2
 
@@ -1811,13 +1805,13 @@ func (s *Server) PsHandler(c *gin.Context) {
 ExpiresAt: v.expiresAt,
 }
 if v.Options != nil {
-mr.ContextLength = v.Options.NumCtx
+mr.ContextLength = v.NumCtx
 }
 // The scheduler waits to set expiresAt, so if a model is loading it's
 // possible that it will be set to the unix epoch. For those cases, just
 // calculate the time w/ the sessionDuration instead.
 var epoch time.Time
-if v.expiresAt == epoch {
+if v.expiresAt.Equal(epoch) {
 mr.ExpiresAt = time.Now().Add(v.sessionDuration)
 }
 
@@ -2000,11 +1994,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 if req.Think == nil {
 req.Think = &api.ThinkValue{Value: true}
 }
-} else {
-if req.Think != nil && req.Think.Bool() {
-c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
-return
-}
+} else if req.Think != nil && req.Think.Bool() {
+c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
+return
 }
 
 r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -196,11 +196,9 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
 if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
 t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
 }
-} else {
+} else if w.Code != http.StatusOK {
 // When debug is disabled, it should attempt normal processing
-if w.Code != http.StatusOK {
-t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
-}
+t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
 }
 })
 }
@@ -401,11 +399,9 @@ func TestChatDebugRenderOnly(t *testing.T) {
 if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
 t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
 }
-} else {
+} else if w.Code != http.StatusOK {
 // When debug is disabled, it should attempt normal processing
-if w.Code != http.StatusOK {
-t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
-}
+t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
 }
 })
 }
@@ -93,7 +93,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
 t.Fatalf("expected status 200, got %d", w.Code)
 }
 
-mock.CompletionResponse.Content = "Hi!"
+mock.Content = "Hi!"
 
 t.Run("chat-like flow uses renderer", func(t *testing.T) {
 // Test that when using messages (chat-like flow), the built-in renderer is used
@@ -109,12 +109,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
 
 // The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
 // When messages are built internally from prompt, it should use the renderer
-if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
-t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
+if !strings.Contains(mock.Prompt, "<|im_start|>") {
+t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.Prompt)
 }
 
-if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
-t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
+if !strings.Contains(mock.Prompt, "<|im_end|>") {
+t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.Prompt)
 }
 })
 
@@ -132,12 +132,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
 }
 
 // Should contain the system message and use renderer format
-if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
-t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
+if !strings.Contains(mock.Prompt, "<|im_start|>system") {
+t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.Prompt)
 }
 
-if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
-t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
+if !strings.Contains(mock.Prompt, "You are a helpful coding assistant.") {
+t.Errorf("expected prompt to contain system message content, got: %s", mock.Prompt)
 }
 })
 
@@ -155,12 +155,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
 }
 
 // Should NOT use the renderer format when custom template is provided
-if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
-t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
+if strings.Contains(mock.Prompt, "<|im_start|>") {
+t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.Prompt)
 }
 
 // Should just be the raw prompt from the template
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "Write a hello world function"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 })
@@ -191,12 +191,12 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
 }
 
 // Should NOT use the renderer format when suffix is provided
-if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
-t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
+if strings.Contains(mock.Prompt, "<|im_start|>") {
+t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.Prompt)
 }
 
 // Should use the suffix template format
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 })
@@ -41,7 +41,7 @@ func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn
 }
 
 func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
-for range strings.Fields(s) {
+for range strings.FieldsSeq(s) {
 tokens = append(tokens, len(tokens))
 }
 
@@ -378,7 +378,7 @@ func TestGenerateChat(t *testing.T) {
 }
 }
 
-mock.CompletionResponse.Content = "Hi!"
+mock.Content = "Hi!"
 t.Run("messages", func(t *testing.T) {
 w := createRequest(t, s.ChatHandler, api.ChatRequest{
 Model: "test",
@@ -392,7 +392,7 @@ func TestGenerateChat(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "user: Hello!\n"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
@@ -422,14 +422,14 @@ func TestGenerateChat(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
 checkChatResponse(t, w.Body, "test-system", "Hi!")
 })
 
-mock.CompletionResponse.Content = "Abra kadabra!"
+mock.Content = "Abra kadabra!"
 t.Run("messages with system", func(t *testing.T) {
 w := createRequest(t, s.ChatHandler, api.ChatRequest{
 Model: "test-system",
@@ -444,7 +444,7 @@ func TestGenerateChat(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
@@ -467,7 +467,7 @@ func TestGenerateChat(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
@@ -985,7 +985,7 @@ func TestGenerate(t *testing.T) {
 }
 }
 
-mock.CompletionResponse.Content = "Hi!"
+mock.Content = "Hi!"
 t.Run("prompt", func(t *testing.T) {
 w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 Model: "test",
@@ -997,7 +997,7 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "User: Hello! "); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
@@ -1025,14 +1025,14 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
 checkGenerateResponse(t, w.Body, "test-system", "Hi!")
 })
 
-mock.CompletionResponse.Content = "Abra kadabra!"
+mock.Content = "Abra kadabra!"
 t.Run("prompt with system", func(t *testing.T) {
 w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
 Model: "test-system",
@@ -1045,7 +1045,7 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
@@ -1067,7 +1067,7 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 
@@ -1097,7 +1097,7 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 })
@@ -1112,7 +1112,7 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "def add("); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 })
@@ -1129,7 +1129,7 @@ func TestGenerate(t *testing.T) {
 t.Errorf("expected status 200, got %d", w.Code)
 }
 
-if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
+if diff := cmp.Diff(mock.Prompt, "Help me write tests."); diff != "" {
 t.Errorf("mismatch (-got +want):\n%s", diff)
 }
 })
Some files were not shown because too many files have changed in this diff.