Compare commits

..

20 Commits

Author SHA1 Message Date
Eva Ho
68a3414761 fix test 2026-01-05 13:24:41 -05:00
Eva Ho
9a5c14c58b address comments 2026-01-05 09:38:44 -05:00
Eva Ho
391fb88bce address comment 2026-01-05 09:38:44 -05:00
Eva Ho
75500c8855 address comment 2026-01-05 09:38:44 -05:00
Eva Ho
e55fbf2475 fix: gofmt formatting in updater_test.go 2026-01-05 09:38:44 -05:00
Eva Ho
c6f941adb3 fix test 2026-01-05 09:38:44 -05:00
Eva Ho
0eb320e74c fix format 2026-01-05 09:38:44 -05:00
Eva Ho
880b4f95b4 fix test 2026-01-05 09:38:44 -05:00
Eva Ho
ba25f4a898 fix test 2026-01-05 09:38:44 -05:00
Eva Ho
dc573715c4 clean up 2026-01-05 09:38:44 -05:00
Eva Ho
5a5d3260f4 fix behaviour when switching between enabled and disabled 2026-01-05 09:38:44 -05:00
Eva Ho
cf7e5e88bc fix test 2026-01-05 09:38:44 -05:00
Eva Ho
e76abac24e app: add upgrade configuration to settings page 2026-01-05 09:38:44 -05:00
Harry V. Kiselev
d087e46bd1 docs/capabilities/vision: fix curl related code snippet (#13615) 2026-01-03 17:27:46 -05:00
lif
37f6f3af24 server: return error when embedding contains NaN or Inf values (#13599)
The normalize function now checks for NaN and Inf values in the
embedding vector before processing. This prevents JSON encoding
failures when models produce invalid floating-point values.

Fixes #13572

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-03 02:20:12 -05:00
Nhan Nguyen
e1bdc23dd2 docs: fix tool name mismatch and trailing commas in api.md example (#13559)
The tool calling example used "get_temperature" for tool_calls but
defined the tool as "get_weather". Also removed trailing commas that
made the JSON invalid.

Fixes #13031
2026-01-03 02:14:53 -05:00
lif
2e78653ff9 app/ui: add swift syntax highlighting support (#13574)
Fixes #13476

Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-03 02:12:08 -05:00
lif
f5f74e12c1 docs: add version note for /v1/responses API (#13596)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-03 01:58:20 -05:00
Vallabh Mahajan
18fdcc94e5 docs: fix broken .md links and render issues (#13550) 2025-12-23 12:44:55 -05:00
Daniel Hiltgen
7ad036992f amd: use GTT on iGPUs on linux (#13196)
On Linux, look at the GTT memory information for iGPUs.
2025-12-23 09:30:05 -08:00
36 changed files with 605 additions and 1628 deletions

View File

@@ -209,9 +209,6 @@ func main() {
st := &store.Store{}
// Initialize native settings with store
SetSettingsStore(st)
// Enable CORS in development mode
if devMode {
os.Setenv("OLLAMA_CORS", "1")
@@ -256,27 +253,28 @@ func main() {
done <- osrv.Run(octx)
}()
restartServer := func() {
ocancel()
<-done
octx, ocancel = context.WithCancel(ctx)
go func() {
done <- osrv.Run(octx)
}()
}
upd := &updater.Updater{Store: st}
uiServer := ui.Server{
Token: token,
Restart: restartServer,
Token: token,
Restart: func() {
ocancel()
<-done
octx, ocancel = context.WithCancel(ctx)
go func() {
done <- osrv.Run(octx)
}()
},
Store: st,
ToolRegistry: toolRegistry,
Dev: devMode,
Logger: slog.Default(),
Updater: upd,
UpdateAvailableFunc: func() {
UpdateAvailable("")
},
}
// Set restart callback for native settings
SetRestartCallback(restartServer)
srv := &http.Server{
Handler: uiServer.Handler(),
}
@@ -292,8 +290,13 @@ func main() {
slog.Debug("background desktop server done")
}()
updater := &updater.Updater{Store: st}
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
// Check for pending updates on startup (show tray notification if update is ready)
if updater.IsUpdatePending() {
slog.Debug("update pending on startup, showing tray notification")
UpdateAvailable("")
}
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
if err != nil {
@@ -356,6 +359,18 @@ func startHiddenTasks() {
// CLI triggered app startup use-case
slog.Info("deferring pending update for fast startup")
} else {
// Check if auto-update is enabled before automatically upgrading
st := &store.Store{}
settings, err := st.Settings()
if err != nil {
slog.Warn("failed to load settings for upgrade check", "error", err)
} else if !settings.AutoUpdateEnabled {
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
// Still show tray notification so user knows update is ready
UpdateAvailable("")
return
}
if err := updater.DoUpgradeAtStartup(); err != nil {
slog.Info("unable to perform upgrade at startup", "error", err)
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization

View File

@@ -1,6 +1,5 @@
#import "app_darwin.h"
#import "menu.h"
#import "settings_darwin.h"
#import "../../updater/updater_darwin.h"
#import <AppKit/AppKit.h>
#import <Cocoa/Cocoa.h>
@@ -253,7 +252,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)settingsUI {
openNativeSettings();
[self uiRequest:@"/settings"];
}
- (void)openUI {

View File

@@ -157,6 +157,10 @@ func UpdateAvailable(ver string) error {
return app.t.UpdateAvailable(ver)
}
// ClearUpdateAvailable forwards to app.t.ClearUpdateAvailable and returns
// whatever error that call reports.
func ClearUpdateAvailable() error {
return app.t.ClearUpdateAvailable()
}
func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
var err error
app.shutdown = shutdown

View File

@@ -1,438 +0,0 @@
//go:build darwin
package main
/*
#cgo CFLAGS: -x objective-c
#cgo LDFLAGS: -framework Cocoa
#include <stdlib.h>
#include "settings_darwin.h"
*/
import "C"
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"encoding/pem"
"fmt"
"log/slog"
"net/http"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"
"unsafe"
"golang.org/x/crypto/ssh"
appauth "github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
)
// settingsStore is a reference to the app's store for settings.
// It is nil until SetSettingsStore is called; the exported accessors in this
// file check for nil before using it.
var settingsStore *store.Store
// SetSettingsStore sets the store reference for settings callbacks.
// The assignment is a plain write to the package-level variable (no locking).
func SetSettingsStore(s *store.Store) {
settingsStore = s
}
// getSettingsExpose returns the Expose flag from the stored settings.
// It returns false when no store has been set or the settings cannot be read.
//
//export getSettingsExpose
func getSettingsExpose() C.bool {
	exposed := false
	if settingsStore != nil {
		if cfg, err := settingsStore.Settings(); err != nil {
			slog.Error("failed to get settings", "error", err)
		} else {
			exposed = cfg.Expose
		}
	}
	return C.bool(exposed)
}
// setSettingsExpose loads the stored settings, replaces the Expose flag with
// the given value and saves the result. Load and save failures are logged and
// otherwise ignored; nothing happens when no store has been set.
//
//export setSettingsExpose
func setSettingsExpose(expose C.bool) {
	if settingsStore == nil {
		return
	}

	current, loadErr := settingsStore.Settings()
	if loadErr != nil {
		slog.Error("failed to get settings", "error", loadErr)
		return
	}

	current.Expose = bool(expose)
	saveErr := settingsStore.SetSettings(current)
	if saveErr != nil {
		slog.Error("failed to save settings", "error", saveErr)
	}
}
// getSettingsBrowser returns the Browser flag from the stored settings.
// It returns false when no store has been set or the settings cannot be read.
//
//export getSettingsBrowser
func getSettingsBrowser() C.bool {
	enabled := false
	if settingsStore != nil {
		if cfg, err := settingsStore.Settings(); err != nil {
			slog.Error("failed to get settings", "error", err)
		} else {
			enabled = cfg.Browser
		}
	}
	return C.bool(enabled)
}
// setSettingsBrowser loads the stored settings, replaces the Browser flag with
// the given value and saves the result. Load and save failures are logged and
// otherwise ignored; nothing happens when no store has been set.
//
//export setSettingsBrowser
func setSettingsBrowser(browser C.bool) {
	if settingsStore == nil {
		return
	}

	current, loadErr := settingsStore.Settings()
	if loadErr != nil {
		slog.Error("failed to get settings", "error", loadErr)
		return
	}

	current.Browser = bool(browser)
	saveErr := settingsStore.SetSettings(current)
	if saveErr != nil {
		slog.Error("failed to save settings", "error", saveErr)
	}
}
// getSettingsModels returns the Models value from the stored settings as a
// newly allocated C string. When no store is set, the settings cannot be read,
// or the stored value is empty, it returns envconfig.Models() instead.
// The result is allocated with C.CString.
//
//export getSettingsModels
func getSettingsModels() *C.char {
	if settingsStore != nil {
		cfg, err := settingsStore.Settings()
		switch {
		case err != nil:
			slog.Error("failed to get settings", "error", err)
		case cfg.Models != "":
			return C.CString(cfg.Models)
		}
	}
	return C.CString(envconfig.Models())
}
// setSettingsModels loads the stored settings, replaces the Models value with
// the Go copy of path and saves the result. Load and save failures are logged
// and otherwise ignored; nothing happens when no store has been set.
//
//export setSettingsModels
func setSettingsModels(path *C.char) {
	if settingsStore == nil {
		return
	}

	current, loadErr := settingsStore.Settings()
	if loadErr != nil {
		slog.Error("failed to get settings", "error", loadErr)
		return
	}

	current.Models = C.GoString(path)
	saveErr := settingsStore.SetSettings(current)
	if saveErr != nil {
		slog.Error("failed to save settings", "error", saveErr)
	}
}
// getSettingsContextLength returns the ContextLength from the stored settings.
// It returns 4096 when no store is set, the settings cannot be read, or the
// stored value is zero or negative.
//
//export getSettingsContextLength
func getSettingsContextLength() C.int {
	const fallback = 4096

	length := fallback
	if settingsStore != nil {
		cfg, err := settingsStore.Settings()
		if err != nil {
			slog.Error("failed to get settings", "error", err)
		} else if cfg.ContextLength > 0 {
			length = cfg.ContextLength
		}
	}
	return C.int(length)
}
// setSettingsContextLength loads the stored settings, replaces ContextLength
// with the given value (stored as-is, without range checks) and saves the
// result. Load and save failures are logged and otherwise ignored; nothing
// happens when no store has been set.
//
//export setSettingsContextLength
func setSettingsContextLength(length C.int) {
	if settingsStore == nil {
		return
	}

	current, loadErr := settingsStore.Settings()
	if loadErr != nil {
		slog.Error("failed to get settings", "error", loadErr)
		return
	}

	current.ContextLength = int(length)
	saveErr := settingsStore.SetSettings(current)
	if saveErr != nil {
		slog.Error("failed to save settings", "error", saveErr)
	}
}
// restartCallback is set by the app to restart the ollama server.
// It is nil until SetRestartCallback is called; restartOllamaServer does
// nothing while it is nil.
var restartCallback func()
// SetRestartCallback sets the function to call when settings change requires a restart.
// The assignment is a plain write to the package-level variable (no locking).
func SetRestartCallback(cb func()) {
restartCallback = cb
}
// restartOllamaServer runs the registered restart callback on a new goroutine
// so the caller is not blocked. It is a no-op when no callback is registered.
//
//export restartOllamaServer
func restartOllamaServer() {
	if restartCallback == nil {
		return
	}
	slog.Info("restarting ollama server due to settings change")
	go restartCallback()
}
// hasOllamaKey reports whether ~/.ollama/id_ed25519 can be stat'ed.
// It returns false when the home directory cannot be determined or the
// stat call fails for any reason.
func hasOllamaKey() bool {
	homeDir, homeErr := os.UserHomeDir()
	if homeErr != nil {
		return false
	}
	if _, statErr := os.Stat(filepath.Join(homeDir, ".ollama", "id_ed25519")); statErr != nil {
		return false
	}
	return true
}
// ensureKeypair generates a new ed25519 keypair under ~/.ollama if the private
// key file (id_ed25519) does not exist yet.
//
// It returns nil without touching anything when the private key is already
// present. If the key's presence cannot be determined (any stat error other
// than "does not exist", e.g. permission denied) the error is returned rather
// than generating a new key, so an existing key is never overwritten by
// mistake. On success both id_ed25519 (mode 0600) and id_ed25519.pub
// (mode 0644) have been written.
func ensureKeypair() error {
	home, err := os.UserHomeDir()
	if err != nil {
		return err
	}
	privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")

	// Only a definite "not exist" result means we should generate a key.
	if _, err := os.Stat(privKeyPath); err == nil {
		return nil // Key exists
	} else if !os.IsNotExist(err) {
		return fmt.Errorf("failed to check for existing key: %w", err)
	}

	// Generate new keypair
	slog.Info("generating new keypair for ollama account")
	pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return fmt.Errorf("failed to generate key: %w", err)
	}

	// Marshal private key (empty comment)
	privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "")
	if err != nil {
		return fmt.Errorf("failed to marshal private key: %w", err)
	}

	// Ensure directory exists
	if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	// Write private key, readable by the owner only
	if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600); err != nil {
		return fmt.Errorf("failed to write private key: %w", err)
	}

	// Write public key in authorized_keys format
	sshPubKey, err := ssh.NewPublicKey(pubKey)
	if err != nil {
		return fmt.Errorf("failed to create ssh public key: %w", err)
	}
	pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
	pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
	if err := os.WriteFile(pubKeyPath, pubKeyBytes, 0o644); err != nil {
		return fmt.Errorf("failed to write public key: %w", err)
	}

	slog.Info("keypair generated successfully")
	return nil
}
// userResponse matches the API response from ollama.com/api/me
// and is the decode target used by fetchUserFromAPI.
type userResponse struct {
Name string `json:"name"`
Email string `json:"email"`
Plan string `json:"plan"`
// AvatarURL is prefixed with "https://ollama.com/" by fetchUserFromAPI when
// it is non-empty and does not already start with "http".
AvatarURL string `json:"avatarurl"`
}
// fetchUserFromAPI fetches user data from ollama.com using signed request.
//
// It signs the string "POST,/api/me?ts=<unix seconds>" with auth.Sign, POSTs to
// https://ollama.com/api/me with the signature as a Bearer token, and decodes
// the JSON body into a userResponse. Any non-200 status is returned as an
// error. On success it also updates cachedAvatarURL and, when a settings store
// is set, caches name/email/plan via settingsStore.SetUser (a failure there is
// only logged). The whole operation is bounded by a 10 second timeout.
func fetchUserFromAPI() (*userResponse, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// The timestamp is part of both the signed string and the request URL.
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
signString := fmt.Sprintf("POST,/api/me?ts=%s", timestamp)
signature, err := auth.Sign(ctx, []byte(signString))
if err != nil {
return nil, fmt.Errorf("failed to sign request: %w", err)
}
endpoint := fmt.Sprintf("https://ollama.com/api/me?ts=%s", timestamp)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
// A new client is created per call; its timeout matches the context timeout above.
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to call ollama.com: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
}
var user userResponse
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
// Make avatar URL absolute
// NOTE(review): a value that already begins with "/" would yield a double
// slash after the host — confirm the format the API returns.
if user.AvatarURL != "" && !strings.HasPrefix(user.AvatarURL, "http") {
user.AvatarURL = "https://ollama.com/" + user.AvatarURL
}
// Cache the avatar URL
// NOTE(review): plain write to a package-level variable; prefetchAccountData
// calls this function from a goroutine — confirm whether readers need
// synchronization.
cachedAvatarURL = user.AvatarURL
// Cache the user data
if settingsStore != nil {
storeUser := store.User{
Name: user.Name,
Email: user.Email,
Plan: user.Plan,
}
if err := settingsStore.SetUser(storeUser); err != nil {
slog.Warn("failed to cache user", "error", err)
}
}
return &user, nil
}
// getAccountName returns the cached user's name as a newly allocated C string.
// It reads only from the settings store and never performs a network request.
// It returns an empty string when no store is set or no user is cached.
//
//export getAccountName
func getAccountName() *C.char {
	name := ""
	if settingsStore != nil {
		if cached, err := settingsStore.User(); err == nil && cached != nil {
			name = cached.Name
		}
	}
	return C.CString(name)
}
// cachedAvatarURL stores the avatar URL from the last API fetch.
// It is written by fetchUserFromAPI and is empty until a fetch succeeds.
var cachedAvatarURL string
// getAccountAvatarURL returns cachedAvatarURL as a newly allocated C string
// (allocated with C.CString).
//export getAccountAvatarURL
func getAccountAvatarURL() *C.char {
return C.CString(cachedAvatarURL)
}
// getAccountEmail returns the cached user's email as a newly allocated C
// string, or an empty string when no store is set or no user is cached.
//
//export getAccountEmail
func getAccountEmail() *C.char {
	if settingsStore == nil {
		return C.CString("")
	}
	cached, err := settingsStore.User()
	if err != nil || cached == nil {
		return C.CString("")
	}
	return C.CString(cached.Email)
}
// getAccountPlan returns the cached user's plan as a newly allocated C string,
// or an empty string when no store is set or no user is cached.
//
//export getAccountPlan
func getAccountPlan() *C.char {
	if settingsStore == nil {
		return C.CString("")
	}
	cached, err := settingsStore.User()
	if err != nil || cached == nil {
		return C.CString("")
	}
	return C.CString(cached.Plan)
}
// signOutAccount clears the cached user from the settings store (when one is
// set) and then deletes ~/.ollama/id_ed25519. Failures are logged; a key file
// that is already missing is not treated as an error.
//
//export signOutAccount
func signOutAccount() {
	if settingsStore != nil {
		if clearErr := settingsStore.ClearUser(); clearErr != nil {
			slog.Error("failed to clear user", "error", clearErr)
		}
	}

	// Also remove the key file
	homeDir, homeErr := os.UserHomeDir()
	if homeErr != nil {
		slog.Error("failed to get home dir", "error", homeErr)
		return
	}

	removeErr := os.Remove(filepath.Join(homeDir, ".ollama", "id_ed25519"))
	if removeErr == nil || os.IsNotExist(removeErr) {
		return
	}
	slog.Error("failed to remove key file", "error", removeErr)
}
// openConnectUrl opens the ollama.com connect page with the "open" command.
//
// It first makes sure a keypair exists and then asks appauth.BuildConnectURL
// for a connect URL. If either step fails, the failure is logged and the basic
// https://ollama.com/connect page is opened instead. A failure to launch
// "open" is logged on every path.
//
//export openConnectUrl
func openConnectUrl() {
	const fallbackURL = "https://ollama.com/connect"

	connectURL := fallbackURL
	if err := ensureKeypair(); err != nil {
		slog.Error("failed to ensure keypair", "error", err)
	} else if built, err := appauth.BuildConnectURL("https://ollama.com"); err != nil {
		slog.Error("failed to build connect URL", "error", err)
	} else {
		connectURL = built
	}

	// Start (not Run): we do not wait for the launched process to finish.
	if err := exec.Command("open", connectURL).Start(); err != nil {
		slog.Error("failed to open connect URL", "error", err)
	}
}
// refreshAccountFromAPI synchronously re-fetches the account via
// fetchUserFromAPI when a key file is present. A failed fetch is logged at
// debug level; without a key file nothing happens.
//
//export refreshAccountFromAPI
func refreshAccountFromAPI() {
	if hasOllamaKey() {
		if _, fetchErr := fetchUserFromAPI(); fetchErr != nil {
			slog.Debug("failed to refresh account", "error", fetchErr)
		}
	}
}
// prefetchAccountData starts a goroutine that fetches the account via
// fetchUserFromAPI when a key file is present, and returns immediately so the
// caller is not blocked. The outcome is logged at debug level.
//
//export prefetchAccountData
func prefetchAccountData() {
	go func() {
		if !hasOllamaKey() {
			return
		}
		if _, fetchErr := fetchUserFromAPI(); fetchErr != nil {
			slog.Debug("failed to prefetch account data", "error", fetchErr)
			return
		}
		slog.Debug("prefetched account data successfully")
	}()
}
// OpenNativeSettings opens the native settings window by calling the C
// function openNativeSettings declared in settings_darwin.h.
func OpenNativeSettings() {
C.openNativeSettings()
}
// freeCString releases a C string with C.free, for example one that was
// allocated with C.CString. The pointer must not be used after this call.
func freeCString(s *C.char) {
C.free(unsafe.Pointer(s))
}

View File

@@ -1,38 +0,0 @@
#import <Cocoa/Cocoa.h>
// SettingsWindowController is the window controller (and window delegate) for
// the native settings window. Its properties hold the controls of the window,
// grouped below by the tab or section they belong to.
@interface SettingsWindowController : NSWindowController <NSWindowDelegate>
// General tab
@property(nonatomic, strong) NSButton *exposeCheckbox;
@property(nonatomic, strong) NSButton *browserCheckbox;
@property(nonatomic, strong) NSSlider *contextLengthSlider;
// Models tab
@property(nonatomic, strong) NSPathControl *modelsPathControl;
@property(nonatomic, strong) NSButton *modelsPathButton;
// Account tab
@property(nonatomic, strong) NSView *avatarView;
@property(nonatomic, strong) NSTextField *avatarInitialLabel;
@property(nonatomic, strong) NSImageView *avatarImageView;
@property(nonatomic, strong) NSTextField *accountNameLabel;
@property(nonatomic, strong) NSTextField *accountEmailLabel;
@property(nonatomic, strong) NSButton *manageButton;
@property(nonatomic, strong) NSButton *signOutButton;
@property(nonatomic, strong) NSButton *signInButton;
@property(nonatomic, strong) NSView *signedInContainer;
@property(nonatomic, strong) NSView *signedOutContainer;
// Plan section
@property(nonatomic, strong) NSView *planContainer;
@property(nonatomic, strong) NSTextField *planNameLabel;
@property(nonatomic, strong) NSButton *upgradeButton;
@property(nonatomic, strong) NSButton *viewUsageButton;
// Class-level accessor for the controller instance
// (presumably a singleton, given the name — confirm in the implementation).
+ (instancetype)sharedController;
// Shows the settings window.
- (void)showSettings;
@end
// C entry point for opening the settings window (called from Go via cgo and from app_darwin.m)
void openNativeSettings(void);

File diff suppressed because it is too large Load Diff

View File

@@ -1,16 +0,0 @@
//go:build windows
package main
import "github.com/ollama/ollama/app/store"
// SetSettingsStore sets the store reference for settings callbacks (stub for Windows).
// The Windows build has no native settings yet, so the argument is ignored.
func SetSettingsStore(s *store.Store) {
// TODO: Implement Windows native settings
}
// SetRestartCallback sets the function to call when settings change requires a restart (stub for Windows).
// The Windows build has no native settings yet, so the callback is ignored.
func SetRestartCallback(cb func()) {
// TODO: Implement Windows native settings
}

View File

@@ -9,12 +9,12 @@ import (
"strings"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
_ "github.com/mattn/go-sqlite3"
)
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 12
const currentSchemaVersion = 13
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -85,6 +85,7 @@ func (db *database) init() error {
think_enabled BOOLEAN NOT NULL DEFAULT 0,
think_level TEXT NOT NULL DEFAULT '',
remote TEXT NOT NULL DEFAULT '', -- deprecated
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
schema_version INTEGER NOT NULL DEFAULT %d
);
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v11 to v12: %w", err)
}
version = 12
case 12:
// add auto_update_enabled column to settings table
if err := db.migrateV12ToV13(); err != nil {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
return nil
}
// migrateV12ToV13 adds the auto_update_enabled column (BOOLEAN, NOT NULL,
// default 1) to the settings table and then records schema version 13.
// A "duplicate column" failure from the ALTER TABLE is tolerated so the
// migration can be re-run; any other failure is returned wrapped.
func (db *database) migrateV12ToV13() error {
	const addColumn = `ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`
	if _, err := db.conn.Exec(addColumn); err != nil && !duplicateColumnError(err) {
		return fmt.Errorf("add auto_update_enabled column: %w", err)
	}

	if _, err := db.conn.Exec(`UPDATE settings SET schema_version = 13`); err != nil {
		return fmt.Errorf("update schema version: %w", err)
	}
	return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -482,19 +504,11 @@ func (db *database) cleanupOrphanedData() error {
}
// duplicateColumnError reports whether err's message contains
// "duplicate column name". It matches on the message text rather than on a
// driver-specific error type, so wrapped errors are recognized too.
// A nil error yields false.
func duplicateColumnError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "duplicate column name")
}
// columnNotExists reports whether err's message contains "no such column".
// It matches on the message text rather than on a driver-specific error type,
// so wrapped errors are recognized too. A nil error yields false.
func columnNotExists(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "no such column")
}
func (db *database) getAllChats() ([]Chat, error) {
@@ -1108,9 +1122,9 @@ func (db *database) getSettings() (Settings, error) {
var s Settings
err := db.conn.QueryRow(`
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
FROM settings
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
if err != nil {
return Settings{}, fmt.Errorf("get settings: %w", err)
}
@@ -1121,8 +1135,8 @@ func (db *database) getSettings() (Settings, error) {
func (db *database) setSettings(s Settings) error {
_, err := db.conn.Exec(`
UPDATE settings
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
if err != nil {
return fmt.Errorf("set settings: %w", err)
}

View File

@@ -169,6 +169,9 @@ type Settings struct {
// SidebarOpen indicates if the chat sidebar is open
SidebarOpen bool
// AutoUpdateEnabled indicates if automatic updates should be downloaded
AutoUpdateEnabled bool
}
type Store struct {

View File

@@ -413,6 +413,7 @@ export class Settings {
ThinkLevel: string;
SelectedModel: string;
SidebarOpen: boolean;
AutoUpdateEnabled: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
@@ -431,6 +432,7 @@ export class Settings {
this.ThinkLevel = source["ThinkLevel"];
this.SelectedModel = source["SelectedModel"];
this.SidebarOpen = source["SidebarOpen"];
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
}
}
export class SettingsResponse {
@@ -467,6 +469,46 @@ export class HealthResponse {
this.healthy = source["healthy"];
}
}
/**
 * Update status fields: current and available version strings plus
 * "available" and "downloaded" flags.
 *
 * The constructor accepts either a plain object or a JSON string (which is
 * parsed first) and copies the four fields by name without validation.
 */
export class UpdateInfo {
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.currentVersion = source["currentVersion"];
this.availableVersion = source["availableVersion"];
this.updateAvailable = source["updateAvailable"];
this.updateDownloaded = source["updateDownloaded"];
}
}
/**
 * Wrapper object whose single `updateInfo` field is converted into an
 * UpdateInfo instance.
 *
 * The constructor accepts either a plain object or a JSON string (which is
 * parsed first).
 */
export class UpdateCheckResponse {
updateInfo: UpdateInfo;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
}
/**
 * Converts raw data into instances of `classs`.
 * Falsy input is returned unchanged; arrays are converted element by element;
 * objects become `new classs(a)`, or, when `asMap` is true, each value of the
 * object is replaced in place by `new classs(value)`. Other values are
 * returned as-is.
 */
convertValues(a: any, classs: any, asMap: boolean = false): any {
if (!a) {
return a;
}
if (Array.isArray(a)) {
// Note: asMap is not forwarded to the per-element call.
return (a as any[]).map(elem => this.convertValues(elem, classs));
} else if ("object" === typeof a) {
if (asMap) {
for (const key of Object.keys(a)) {
a[key] = new classs(a[key]);
}
return a;
}
return new classs(a);
}
return a;
}
}
export class User {
id: string;
email: string;

View File

@@ -414,3 +414,54 @@ export async function fetchHealth(): Promise<boolean> {
return false;
}
}
/**
 * Fetches the version string from `${API_BASE}/api/version`.
 *
 * Never rejects: resolves to "Unknown" when the response is not OK, when the
 * body has no truthy `version` field, or when the request throws (the error
 * is logged to the console).
 */
export async function getCurrentVersion(): Promise<string> {
  const fallback = "Unknown";
  try {
    const res = await fetch(`${API_BASE}/api/version`, {
      method: "GET",
      headers: { "Content-Type": "application/json" },
    });
    if (!res.ok) {
      return fallback;
    }
    const body = await res.json();
    return body.version || fallback;
  } catch (error) {
    console.error("Error fetching version:", error);
    return fallback;
  }
}
/**
 * Requests `${API_BASE}/api/v1/update/check` and resolves to the `updateInfo`
 * field of the JSON response.
 *
 * @throws Error("Failed to check for update") when the response is not OK;
 *   network and JSON parsing errors propagate unchanged.
 */
export async function checkForUpdate(): Promise<{
  currentVersion: string;
  availableVersion: string;
  updateAvailable: boolean;
  updateDownloaded: boolean;
}> {
  const res = await fetch(`${API_BASE}/api/v1/update/check`, {
    method: "GET",
    headers: { "Content-Type": "application/json" },
  });
  if (res.ok) {
    const payload = await res.json();
    return payload.updateInfo;
  }
  throw new Error("Failed to check for update");
}
/**
 * POSTs to `${API_BASE}/api/v1/update/install`.
 *
 * Resolves with no value when the response is OK.
 *
 * @throws Error carrying the response body text, or "Failed to install update"
 *   when that text is empty, if the response is not OK.
 */
export async function installUpdate(): Promise<void> {
  const res = await fetch(`${API_BASE}/api/v1/update/install`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
  });
  if (res.ok) {
    return;
  }
  const message = await res.text();
  throw new Error(message || "Failed to install update");
}

View File

@@ -14,12 +14,13 @@ import {
XMarkIcon,
CogIcon,
ArrowLeftIcon,
ArrowDownTrayIcon,
} from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { getSettings, updateSettings } from "@/api";
import { getSettings, updateSettings, checkForUpdate } from "@/api";
function AnimatedDots() {
return (
@@ -39,6 +40,12 @@ export default function Settings() {
const queryClient = useQueryClient();
const [showSaved, setShowSaved] = useState(false);
const [restartMessage, setRestartMessage] = useState(false);
const [updateInfo, setUpdateInfo] = useState<{
currentVersion: string;
availableVersion: string;
updateAvailable: boolean;
updateDownloaded: boolean;
} | null>(null);
const {
user,
isAuthenticated,
@@ -76,8 +83,22 @@ export default function Settings() {
useEffect(() => {
refetchUser();
// Check for updates
checkForUpdate()
.then(setUpdateInfo)
.catch((err) => console.error("Error checking for update:", err));
}, []); // eslint-disable-line react-hooks/exhaustive-deps
// Refresh update info when auto-update toggle changes
useEffect(() => {
if (settings?.AutoUpdateEnabled !== undefined) {
checkForUpdate()
.then(setUpdateInfo)
.catch((err) => console.error("Error checking for update:", err));
}
}, [settings?.AutoUpdateEnabled]);
useEffect(() => {
const handleFocus = () => {
if (isAwaitingConnection && pollingInterval) {
@@ -344,6 +365,58 @@ export default function Settings() {
{/* Local Configuration */}
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
<div className="space-y-4 p-4">
{/* Auto Update */}
<Field>
<div className="flex items-start justify-between gap-4">
<div className="flex items-start space-x-3 flex-1">
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
<div className="flex-1">
<Label>Auto-download updates</Label>
<Description>
{settings.AutoUpdateEnabled ? (
<>
Automatically downloads updates when available.
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
Current version: {updateInfo?.currentVersion || "Loading..."}
</div>
</>
) : (
<>
Manually download updates.
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
<div className="space-y-2 text-sm">
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
</div>
{updateInfo?.availableVersion && (
<div className="flex justify-between">
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
</div>
)}
</div>
<a
href="https://ollama.com/download"
target="_blank"
rel="noopener noreferrer"
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
>
Download new version
</a>
</div>
</>
)}
</Description>
</div>
</div>
<div className="flex-shrink-0">
<Switch
checked={settings.AutoUpdateEnabled}
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
/>
</div>
</div>
</Field>
{/* Expose Ollama */}
<Field>
<div className="flex items-start justify-between gap-4">

View File

@@ -147,6 +147,7 @@ export const highlighterPromise = createHighlighter({
"c",
"cpp",
"sql",
"swift",
"yaml",
"markdown",
],

View File

@@ -100,6 +100,17 @@ type HealthResponse struct {
Healthy bool `json:"healthy"`
}
// UpdateInfo is the update status payload. The JSON field names are
// camelCase, as given by the struct tags.
type UpdateInfo struct {
// CurrentVersion is a version string for the running build.
CurrentVersion string `json:"currentVersion"`
// AvailableVersion is a version string for the available update.
AvailableVersion string `json:"availableVersion"`
UpdateAvailable bool `json:"updateAvailable"`
UpdateDownloaded bool `json:"updateDownloaded"`
}
// UpdateCheckResponse wraps an UpdateInfo under the "updateInfo" JSON key.
type UpdateCheckResponse struct {
UpdateInfo UpdateInfo `json:"updateInfo"`
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`

View File

@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/app/tools"
"github.com/ollama/ollama/app/types/not"
"github.com/ollama/ollama/app/ui/responses"
"github.com/ollama/ollama/app/updater"
"github.com/ollama/ollama/app/version"
ollamaAuth "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
@@ -106,6 +107,18 @@ type Server struct {
// Dev is true if the server is running in development mode
Dev bool
// Updater for checking and downloading updates
Updater UpdaterInterface
UpdateAvailableFunc func()
}
// UpdaterInterface defines the methods we need from the updater
type UpdaterInterface interface {
// CheckForUpdate returns (bool, string, error); the update-check handler
// treats the bool as "update available" and the string as the available
// version.
CheckForUpdate(ctx context.Context) (bool, string, error)
// InstallAndRestart is called by the install handler after it has written
// its response.
InstallAndRestart() error
// CancelOngoingDownload is called by the settings handler when auto-update
// is switched off.
CancelOngoingDownload()
// TriggerImmediateCheck is called by the settings handler when auto-update
// is switched back on and no update is already staged.
TriggerImmediateCheck()
}
func (s *Server) log() *slog.Logger {
@@ -284,6 +297,8 @@ func (s *Server) Handler() http.Handler {
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings))
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
// Ollama proxy endpoints
ollamaProxy := s.ollamaProxy()
@@ -1448,6 +1463,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
return fmt.Errorf("failed to save settings: %w", err)
}
// Handle auto-update toggle changes
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
if !settings.AutoUpdateEnabled {
// Auto-update disabled: cancel any ongoing download
if s.Updater != nil {
s.Updater.CancelOngoingDownload()
}
} else {
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
s.UpdateAvailableFunc()
} else if s.Updater != nil {
// Trigger the background checker to run immediately
s.Updater.TriggerImmediateCheck()
}
}
}
if old.ContextLength != settings.ContextLength ||
old.Models != settings.Models ||
old.Expose != settings.Expose {
@@ -1524,6 +1557,73 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
return json.NewEncoder(w).Encode(response)
}
// checkForUpdate handles the update-check request. It asks the updater
// whether a newer version exists and writes an UpdateCheckResponse as JSON.
//
// It returns an error when no updater is configured. A failed check is only
// logged: the response is still written, using whatever values the updater
// returned alongside the error.
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
	if s.Updater == nil {
		return fmt.Errorf("updater not available")
	}

	available, latest, checkErr := s.Updater.CheckForUpdate(r.Context())
	if checkErr != nil {
		// Don't return error, just log it and continue with no update available
		s.log().Warn("failed to check for update", "error", checkErr)
	}

	var payload responses.UpdateCheckResponse
	payload.UpdateInfo.CurrentVersion = version.Version
	payload.UpdateInfo.AvailableVersion = latest
	payload.UpdateInfo.UpdateAvailable = available
	payload.UpdateInfo.UpdateDownloaded = updater.UpdateDownloaded

	w.Header().Set("Content-Type", "application/json")
	return json.NewEncoder(w).Encode(payload)
}
// installUpdate acknowledges an install request with a small JSON body and
// then, in the background, asks the updater to install the staged update and
// restart the app.
//
// It returns an error (and starts nothing) when the request is not a POST,
// when no updater is configured, or when no update has been downloaded yet.
// Failures of the install itself happen after the response has been sent and
// are therefore only logged.
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
	// Delay between acknowledging the request and starting the install, so
	// the client has a chance to read the response before the restart.
	const restartDelay = time.Second

	if r.Method != http.MethodPost {
		return fmt.Errorf("method not allowed")
	}

	if s.Updater == nil {
		s.log().Error("install failed: updater not available")
		return fmt.Errorf("updater not available")
	}

	// Nothing to install unless a download has been staged.
	if !updater.UpdateDownloaded {
		s.log().Error("install failed: no update downloaded")
		return fmt.Errorf("no update downloaded")
	}

	// Send response before restarting
	response := map[string]any{
		"success": true,
		"message": "Installing update and restarting...",
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(response); err != nil {
		return err
	}

	// Push the buffered body to the client now. Sleeping inside the handler
	// does not do that; it only postpones the handler's return. If the writer
	// is wrapped and does not expose Flusher, the response still goes out
	// when the handler returns, well inside restartDelay.
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}

	// Trigger the upgrade and restart off the request goroutine so the
	// handler can return immediately.
	go func() {
		time.Sleep(restartDelay)
		if err := s.Updater.InstallAndRestart(); err != nil {
			s.log().Error("failed to install update", "error", err)
		}
	}()

	return nil
}
func userAgent() string {
buildinfo, _ := debug.ReadBuildInfo()

View File

@@ -19,6 +19,7 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"
"github.com/ollama/ollama/app/store"
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
query := requestURL.Query()
query.Add("os", runtime.GOOS)
query.Add("arch", runtime.GOARCH)
query.Add("version", version.Version)
currentVersion := version.Version
query.Add("version", currentVersion)
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
// The original macOS app used to use the device ID
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
}
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
// Create a cancellable context for this download
downloadCtx, cancel := context.WithCancel(ctx)
u.cancelDownloadLock.Lock()
u.cancelDownload = cancel
u.cancelDownloadLock.Unlock()
defer func() {
u.cancelDownloadLock.Lock()
u.cancelDownload = nil
u.cancelDownloadLock.Unlock()
cancel()
}()
// Do a head first to check etag info
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
if err != nil {
return err
}
// In case of slow downloads, continue the update check in the background
bgctx, cancel := context.WithCancel(ctx)
defer cancel()
bgctx, bgcancel := context.WithCancel(downloadCtx)
defer bgcancel()
go func() {
for {
select {
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
_, err = os.Stat(stageFilename)
if err == nil {
slog.Info("update already downloaded", "bundle", stageFilename)
UpdateDownloaded = true
return nil
}
@@ -244,34 +259,95 @@ func cleanupOldDownloads(stageDir string) {
}
type Updater struct {
Store *store.Store
Store *store.Store
cancelDownload context.CancelFunc
cancelDownloadLock sync.Mutex
checkNow chan struct{}
}
// CancelOngoingDownload aborts the download that is currently registered on
// the updater, if there is one, and clears the registration. It is a no-op
// when no cancel function is set.
func (u *Updater) CancelOngoingDownload() {
	u.cancelDownloadLock.Lock()
	defer u.cancelDownloadLock.Unlock()

	if u.cancelDownload == nil {
		return
	}

	slog.Info("cancelling ongoing update download")
	u.cancelDownload()
	// Drop the reference so a second call does not cancel again.
	u.cancelDownload = nil
}
// TriggerImmediateCheck signals the background checker to check for updates
// immediately. It never blocks: if the checker has not been started
// (checkNow is nil) the call is a no-op, and if a trigger is already queued
// the new one is dropped, since the pending trigger will cause the same check.
func (u *Updater) TriggerImmediateCheck() {
	if u.checkNow == nil {
		return
	}

	// Non-blocking send. A plain send would stall the caller whenever the
	// channel's single buffer slot is already full, e.g. while the checker is
	// still in its startup delay or busy with a download.
	select {
	case u.checkNow <- struct{}{}:
	default:
	}
}
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
u.checkNow = make(chan struct{}, 1)
go func() {
// Don't blast an update message immediately after startup
time.Sleep(UpdateCheckInitialDelay)
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
ticker := time.NewTicker(UpdateCheckInterval)
defer ticker.Stop()
for {
available, resp := u.checkForUpdate(ctx)
if available {
err := u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
} else {
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
}
}
}
select {
case <-ctx.Done():
slog.Debug("stopping background update checker")
return
default:
time.Sleep(UpdateCheckInterval)
case <-u.checkNow:
// Immediate check triggered
case <-ticker.C:
// Regular interval check
}
// Always check for updates
available, resp := u.checkForUpdate(ctx)
if !available {
continue
}
// Update is available - check if auto-update is enabled for downloading
settings, err := u.Store.Settings()
if err != nil {
slog.Error("failed to load settings", "error", err)
continue
}
if !settings.AutoUpdateEnabled {
// Auto-update disabled - don't download, just log
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
continue
}
// Auto-update is enabled - download
err = u.DownloadNewRelease(ctx, resp)
if err != nil {
slog.Error("failed to download new release", "error", err)
continue
}
// Download successful - show tray notification (regardless of toggle state)
err = cb(resp.UpdateVersion)
if err != nil {
slog.Warn("failed to register update available with tray", "error", err)
}
}
}()
}
// CheckForUpdate runs a single update check and returns whether an update is
// available together with the version string from the update response. The
// returned error is currently always nil.
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
	found, info := u.checkForUpdate(ctx)
	latest := info.UpdateVersion
	return found, latest, nil
}
// InstallAndRestart hands off to DoUpgrade(true) when the package-level
// UpdateDownloaded flag is set, and returns an error without doing anything
// otherwise.
func (u *Updater) InstallAndRestart() error {
	if UpdateDownloaded {
		slog.Info("installing update and restarting")
		return DoUpgrade(true)
	}
	return fmt.Errorf("no update downloaded")
}

View File

@@ -85,7 +85,17 @@ func TestBackgoundChecker(t *testing.T) {
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed
defer updater.Store.Close()
settings, err := updater.Store.Settings()
if err != nil {
t.Fatal(err)
}
settings.AutoUpdateEnabled = true
if err := updater.Store.SetSettings(settings); err != nil {
t.Fatal(err)
}
updater.StartBackgroundUpdaterChecker(ctx, cb)
select {
case <-stallTimer.C:

View File

@@ -369,24 +369,24 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
return nil
}
// removeMenuItem removes the item with the given menuItemId from the menu
// stored under parentId by calling the RemoveMenu procedure with
// MF_BYCOMMAND, then drops the item from the tray's visible-item bookkeeping.
//
// It returns the error from the procedure call when the call reports failure
// (zero result) with an error code other than ERROR_SUCCESS; in that case the
// visible-item bookkeeping is left untouched.
func (t *winTray) removeMenuItem(menuItemId, parentId uint32) error {
	const ERROR_SUCCESS syscall.Errno = 0

	// Only the map lookup needs the lock; release it before the call.
	t.muMenus.RLock()
	parentMenu := t.menus[parentId]
	t.muMenus.RUnlock()

	ret, _, callErr := pRemoveMenu.Call(uintptr(parentMenu), uintptr(menuItemId), MF_BYCOMMAND)
	if ret == 0 && callErr.(syscall.Errno) != ERROR_SUCCESS {
		return callErr
	}

	t.delFromVisibleItems(parentId, menuItemId)
	return nil
}
func (t *winTray) showMenu() error {
p := point{}

View File

@@ -30,6 +30,7 @@ var (
pPostQuitMessage = u32.NewProc("PostQuitMessage")
pRegisterClass = u32.NewProc("RegisterClassExW")
pRegisterWindowMessage = u32.NewProc("RegisterWindowMessageW")
pRemoveMenu = u32.NewProc("RemoveMenu")
pSendMessage = u32.NewProc("SendMessageW")
pSetForegroundWindow = u32.NewProc("SetForegroundWindow")
pSetMenuInfo = u32.NewProc("SetMenuInfo")

View File

@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
"tool_calls": [
{
"function": {
"name": "get_temperature",
"name": "get_weather",
"arguments": {
"city": "Toronto"
}
},
}
}
]
},
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_temperature",
"tool_name": "get_weather"
}
],
"stream": false,

View File

@@ -277,6 +277,8 @@ curl -X POST http://localhost:11434/v1/chat/completions \
### `/v1/responses`
> Note: Added in Ollama v0.13.3
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
#### Supported features

View File

@@ -36,7 +36,6 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
}],
"stream": false
}'
"
```
</Tab>
<Tab title="Python">

View File

@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu.md).
Please refer to the [GPU docs](./gpu).
## How can I specify the context window size?

View File

@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
For building locally to support older GPUs, see the [development guide](./development#linux-cuda-nvidia).
### GPU Selection
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library:
> [!NOTE]
> **NOTE:**
> Additional AMD GPU support is provided by the Vulkan Library - see below.
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> [!NOTE]
> **NOTE:**
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
described in the [FAQ](faq#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
### Linux NVIDIA Troubleshooting
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
Sometimes Ollama can have difficulty initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem.

View File

@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 976 insertions(+), 17 deletions(-)
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++-
ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 ++++++++++
9 files changed, 1005 insertions(+), 17 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
@@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6852d2e20..48cdb1dcf 100644
index 6852d2e20..334a30135 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -109,7 +109,7 @@ index 6852d2e20..48cdb1dcf 100644
+
+#if defined(GGML_USE_HIP)
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
@@ -204,7 +204,7 @@ index 4e162258d..d89e35a8e 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c58..1c07e767a 100644
index fe57d4c58..dba8f4695 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
@@ -216,7 +216,7 @@ index fe57d4c58..1c07e767a 100644
+GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
+GGML_API void ggml_nvml_release();
+GGML_API int ggml_hip_mgmt_init();
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
+GGML_API void ggml_hip_mgmt_release();
+
#ifdef __cplusplus
@@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644
/* .async = */ true,
/* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 5349bce24..d43d46d1d 100644
index 5349bce24..0103fd03a 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -236,6 +236,7 @@ class vk_memory_logger;
@@ -334,7 +334,7 @@ index 5349bce24..d43d46d1d 100644
+ switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
+ if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release();
@@ -505,10 +505,10 @@ index 5349bce24..d43d46d1d 100644
}
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 000000000..c1949b899
index 000000000..23c765806
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,529 @@
@@ -0,0 +1,558 @@
+#include "ggml.h"
+#include "ggml-impl.h"
+
@@ -842,7 +842,7 @@ index 000000000..c1949b899
+ if (gpus != NULL) gpus->pVtbl->Release(gpus); \
+ if (gpu != NULL) gpu->pVtbl->Release(gpu)
+
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
+ std::lock_guard<std::mutex> lock(ggml_adlx_lock);
+ if (adlx.handle == NULL) {
+ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@@ -966,13 +966,16 @@ index 000000000..c1949b899
+ return 0;
+}
+void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
+ const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+
+
+ glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+
@@ -1006,7 +1009,6 @@ index 000000000..c1949b899
+
+ uint64_t memory;
+ totalFileStream >> memory;
+ *total = memory;
+
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
@@ -1019,6 +1021,33 @@ index 000000000..c1949b899
+
+ uint64_t memoryUsed;
+ usedFileStream >> memoryUsed;
+
+ if (is_integrated_gpu) {
+ std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gtt;
+ totalFileStream >> gtt;
+ std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gttUsed;
+ usedFileStream >> gttUsed;
+ memory += gtt;
+ memoryUsed += gttUsed;
+ }
+
+ *total = memory;
+ *free = memory - memoryUsed;
+
+ file.close();

View File

@@ -24,12 +24,12 @@ index 99ae293cc..9a134b7af 100644
set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 1c07e767a..0da3e065b 100644
index dba8f4695..7e17032c7 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
GGML_API void ggml_hip_mgmt_release();
+GGML_API int ggml_dxgi_pdh_init();
+GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
@@ -38,7 +38,7 @@ index 1c07e767a..0da3e065b 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index d43d46d1d..df79f9f79 100644
index 0103fd03a..9cc4ebdef 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@@ -10,7 +10,7 @@ fallback to cpu
1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 48cdb1dcf..3102d7ea7 100644
index 334a30135..5c9dfd032 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g

View File

@@ -4436,7 +4436,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
#if defined(GGML_USE_HIP)
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
ggml_hip_mgmt_release();

View File

@@ -682,7 +682,7 @@ GGML_API int ggml_nvml_init();
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
GGML_API void ggml_hip_mgmt_release();
GGML_API int ggml_dxgi_pdh_init();
GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);

View File

@@ -13710,7 +13710,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
switch (props2.properties.vendorID) {
case VK_VENDOR_ID_AMD:
if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
ggml_hip_mgmt_release();

View File

@@ -331,7 +331,7 @@ void ggml_hip_mgmt_release() {
if (gpus != NULL) gpus->pVtbl->Release(gpus); \
if (gpu != NULL) gpu->pVtbl->Release(gpu)
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
std::lock_guard<std::mutex> lock(ggml_adlx_lock);
if (adlx.handle == NULL) {
GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@@ -455,13 +455,16 @@ int ggml_hip_mgmt_init() {
return 0;
}
void ggml_hip_mgmt_release() {}
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
const std::string drmTotalMemoryFile = "mem_info_vram_total";
const std::string drmUsedMemoryFile = "mem_info_vram_used";
const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
glob_t glob_result;
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
@@ -495,7 +498,6 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
uint64_t memory;
totalFileStream >> memory;
*total = memory;
std::string usedFile = dir + "/" + drmUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
@@ -508,6 +510,33 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
uint64_t memoryUsed;
usedFileStream >> memoryUsed;
if (is_integrated_gpu) {
std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
std::ifstream totalFileStream(totalFile.c_str());
if (!totalFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gtt;
totalFileStream >> gtt;
std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
if (!usedFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gttUsed;
usedFileStream >> gttUsed;
memory += gtt;
memoryUsed += gttUsed;
}
*total = memory;
*free = memory - memoryUsed;
file.close();

View File

@@ -752,9 +752,15 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return err
}
// TODO: this first normalization should be done by the model
embedding = normalize(embedding)
embedding, err = normalize(embedding)
if err != nil {
return err
}
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
embedding = normalize(embedding[:req.Dimensions])
embedding, err = normalize(embedding[:req.Dimensions])
if err != nil {
return err
}
}
embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
@@ -787,9 +793,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
func normalize(vec []float32) ([]float32, error) {
var sum float32
for _, v := range vec {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
return nil, errors.New("embedding contains NaN or Inf values")
}
sum += v * v
}
@@ -797,7 +806,7 @@ func normalize(vec []float32) []float32 {
for i := range vec {
vec[i] *= norm
}
return vec
return vec, nil
}
func (s *Server) EmbeddingsHandler(c *gin.Context) {
@@ -2395,4 +2404,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}

View File

@@ -723,15 +723,20 @@ func TestShow(t *testing.T) {
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32
input []float32
expectError bool
}
testCases := []testCase{
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
{input: []float32{1}, expectError: false},
{input: []float32{0, 1, 2, 3}, expectError: false},
{input: []float32{0.1, 0.2, 0.3}, expectError: false},
{input: []float32{-0.1, 0.2, 0.3, -0.4}, expectError: false},
{input: []float32{0, 0, 0}, expectError: false},
{input: []float32{float32(math.NaN()), 0.2, 0.3}, expectError: true},
{input: []float32{0.1, float32(math.NaN()), 0.3}, expectError: true},
{input: []float32{float32(math.Inf(1)), 0.2, 0.3}, expectError: true},
{input: []float32{float32(math.Inf(-1)), 0.2, 0.3}, expectError: true},
}
isNormalized := func(vec []float32) (res bool) {
@@ -748,9 +753,18 @@ func TestNormalize(t *testing.T) {
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := normalize(tc.input)
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
normalized, err := normalize(tc.input)
if tc.expectError {
if err == nil {
t.Errorf("Expected error for input %v, but got none", tc.input)
}
} else {
if err != nil {
t.Errorf("Unexpected error for input %v: %v", tc.input, err)
}
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
}
})
}