Compare commits
20 Commits
jmorganca/
...
hoyyeva/up
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68a3414761 | ||
|
|
9a5c14c58b | ||
|
|
391fb88bce | ||
|
|
75500c8855 | ||
|
|
e55fbf2475 | ||
|
|
c6f941adb3 | ||
|
|
0eb320e74c | ||
|
|
880b4f95b4 | ||
|
|
ba25f4a898 | ||
|
|
dc573715c4 | ||
|
|
5a5d3260f4 | ||
|
|
cf7e5e88bc | ||
|
|
e76abac24e | ||
|
|
d087e46bd1 | ||
|
|
37f6f3af24 | ||
|
|
e1bdc23dd2 | ||
|
|
2e78653ff9 | ||
|
|
f5f74e12c1 | ||
|
|
18fdcc94e5 | ||
|
|
7ad036992f |
@@ -209,9 +209,6 @@ func main() {
|
||||
|
||||
st := &store.Store{}
|
||||
|
||||
// Initialize native settings with store
|
||||
SetSettingsStore(st)
|
||||
|
||||
// Enable CORS in development mode
|
||||
if devMode {
|
||||
os.Setenv("OLLAMA_CORS", "1")
|
||||
@@ -256,27 +253,28 @@ func main() {
|
||||
done <- osrv.Run(octx)
|
||||
}()
|
||||
|
||||
restartServer := func() {
|
||||
ocancel()
|
||||
<-done
|
||||
octx, ocancel = context.WithCancel(ctx)
|
||||
go func() {
|
||||
done <- osrv.Run(octx)
|
||||
}()
|
||||
}
|
||||
upd := &updater.Updater{Store: st}
|
||||
|
||||
uiServer := ui.Server{
|
||||
Token: token,
|
||||
Restart: restartServer,
|
||||
Token: token,
|
||||
Restart: func() {
|
||||
ocancel()
|
||||
<-done
|
||||
octx, ocancel = context.WithCancel(ctx)
|
||||
go func() {
|
||||
done <- osrv.Run(octx)
|
||||
}()
|
||||
},
|
||||
Store: st,
|
||||
ToolRegistry: toolRegistry,
|
||||
Dev: devMode,
|
||||
Logger: slog.Default(),
|
||||
Updater: upd,
|
||||
UpdateAvailableFunc: func() {
|
||||
UpdateAvailable("")
|
||||
},
|
||||
}
|
||||
|
||||
// Set restart callback for native settings
|
||||
SetRestartCallback(restartServer)
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: uiServer.Handler(),
|
||||
}
|
||||
@@ -292,8 +290,13 @@ func main() {
|
||||
slog.Debug("background desktop server done")
|
||||
}()
|
||||
|
||||
updater := &updater.Updater{Store: st}
|
||||
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||
|
||||
// Check for pending updates on startup (show tray notification if update is ready)
|
||||
if updater.IsUpdatePending() {
|
||||
slog.Debug("update pending on startup, showing tray notification")
|
||||
UpdateAvailable("")
|
||||
}
|
||||
|
||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||
if err != nil {
|
||||
@@ -356,6 +359,18 @@ func startHiddenTasks() {
|
||||
// CLI triggered app startup use-case
|
||||
slog.Info("deferring pending update for fast startup")
|
||||
} else {
|
||||
// Check if auto-update is enabled before automatically upgrading
|
||||
st := &store.Store{}
|
||||
settings, err := st.Settings()
|
||||
if err != nil {
|
||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
||||
} else if !settings.AutoUpdateEnabled {
|
||||
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
|
||||
// Still show tray notification so user knows update is ready
|
||||
UpdateAvailable("")
|
||||
return
|
||||
}
|
||||
|
||||
if err := updater.DoUpgradeAtStartup(); err != nil {
|
||||
slog.Info("unable to perform upgrade at startup", "error", err)
|
||||
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#import "app_darwin.h"
|
||||
#import "menu.h"
|
||||
#import "settings_darwin.h"
|
||||
#import "../../updater/updater_darwin.h"
|
||||
#import <AppKit/AppKit.h>
|
||||
#import <Cocoa/Cocoa.h>
|
||||
@@ -253,7 +252,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
|
||||
}
|
||||
|
||||
- (void)settingsUI {
|
||||
openNativeSettings();
|
||||
[self uiRequest:@"/settings"];
|
||||
}
|
||||
|
||||
- (void)openUI {
|
||||
|
||||
@@ -157,6 +157,10 @@ func UpdateAvailable(ver string) error {
|
||||
return app.t.UpdateAvailable(ver)
|
||||
}
|
||||
|
||||
func ClearUpdateAvailable() error {
|
||||
return app.t.ClearUpdateAvailable()
|
||||
}
|
||||
|
||||
func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
||||
var err error
|
||||
app.shutdown = shutdown
|
||||
|
||||
@@ -1,438 +0,0 @@
|
||||
//go:build darwin
|
||||
|
||||
package main
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -x objective-c
|
||||
#cgo LDFLAGS: -framework Cocoa
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "settings_darwin.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
appauth "github.com/ollama/ollama/app/auth"
|
||||
"github.com/ollama/ollama/app/store"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// settingsStore is a reference to the app's store for settings
|
||||
var settingsStore *store.Store
|
||||
|
||||
// SetSettingsStore sets the store reference for settings callbacks
|
||||
func SetSettingsStore(s *store.Store) {
|
||||
settingsStore = s
|
||||
}
|
||||
|
||||
//export getSettingsExpose
|
||||
func getSettingsExpose() C.bool {
|
||||
if settingsStore == nil {
|
||||
return C.bool(false)
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return C.bool(false)
|
||||
}
|
||||
return C.bool(settings.Expose)
|
||||
}
|
||||
|
||||
//export setSettingsExpose
|
||||
func setSettingsExpose(expose C.bool) {
|
||||
if settingsStore == nil {
|
||||
return
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return
|
||||
}
|
||||
settings.Expose = bool(expose)
|
||||
if err := settingsStore.SetSettings(settings); err != nil {
|
||||
slog.Error("failed to save settings", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
//export getSettingsBrowser
|
||||
func getSettingsBrowser() C.bool {
|
||||
if settingsStore == nil {
|
||||
return C.bool(false)
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return C.bool(false)
|
||||
}
|
||||
return C.bool(settings.Browser)
|
||||
}
|
||||
|
||||
//export setSettingsBrowser
|
||||
func setSettingsBrowser(browser C.bool) {
|
||||
if settingsStore == nil {
|
||||
return
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return
|
||||
}
|
||||
settings.Browser = bool(browser)
|
||||
if err := settingsStore.SetSettings(settings); err != nil {
|
||||
slog.Error("failed to save settings", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
//export getSettingsModels
|
||||
func getSettingsModels() *C.char {
|
||||
if settingsStore == nil {
|
||||
return C.CString(envconfig.Models())
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return C.CString(envconfig.Models())
|
||||
}
|
||||
if settings.Models == "" {
|
||||
return C.CString(envconfig.Models())
|
||||
}
|
||||
return C.CString(settings.Models)
|
||||
}
|
||||
|
||||
//export setSettingsModels
|
||||
func setSettingsModels(path *C.char) {
|
||||
if settingsStore == nil {
|
||||
return
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return
|
||||
}
|
||||
settings.Models = C.GoString(path)
|
||||
if err := settingsStore.SetSettings(settings); err != nil {
|
||||
slog.Error("failed to save settings", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
//export getSettingsContextLength
|
||||
func getSettingsContextLength() C.int {
|
||||
if settingsStore == nil {
|
||||
return C.int(4096)
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return C.int(4096)
|
||||
}
|
||||
if settings.ContextLength <= 0 {
|
||||
return C.int(4096)
|
||||
}
|
||||
return C.int(settings.ContextLength)
|
||||
}
|
||||
|
||||
//export setSettingsContextLength
|
||||
func setSettingsContextLength(length C.int) {
|
||||
if settingsStore == nil {
|
||||
return
|
||||
}
|
||||
settings, err := settingsStore.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to get settings", "error", err)
|
||||
return
|
||||
}
|
||||
settings.ContextLength = int(length)
|
||||
if err := settingsStore.SetSettings(settings); err != nil {
|
||||
slog.Error("failed to save settings", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// restartCallback is set by the app to restart the ollama server
|
||||
var restartCallback func()
|
||||
|
||||
// SetRestartCallback sets the function to call when settings change requires a restart
|
||||
func SetRestartCallback(cb func()) {
|
||||
restartCallback = cb
|
||||
}
|
||||
|
||||
//export restartOllamaServer
|
||||
func restartOllamaServer() {
|
||||
if restartCallback != nil {
|
||||
slog.Info("restarting ollama server due to settings change")
|
||||
go restartCallback()
|
||||
}
|
||||
}
|
||||
|
||||
// hasOllamaKey checks if the user has an Ollama key file
|
||||
func hasOllamaKey() bool {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||
_, err = os.Stat(keyPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ensureKeypair generates a new keypair if one doesn't exist
|
||||
func ensureKeypair() error {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
privKeyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||
|
||||
// Check if key already exists
|
||||
if _, err := os.Stat(privKeyPath); err == nil {
|
||||
return nil // Key exists
|
||||
}
|
||||
|
||||
// Generate new keypair
|
||||
slog.Info("generating new keypair for ollama account")
|
||||
|
||||
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key: %w", err)
|
||||
}
|
||||
|
||||
// Marshal private key
|
||||
privKeyBytes, err := ssh.MarshalPrivateKey(privKey, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal private key: %w", err)
|
||||
}
|
||||
|
||||
// Ensure directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(privKeyPath), 0o755); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %w", err)
|
||||
}
|
||||
|
||||
// Write private key
|
||||
if err := os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600); err != nil {
|
||||
return fmt.Errorf("failed to write private key: %w", err)
|
||||
}
|
||||
|
||||
// Write public key
|
||||
sshPubKey, err := ssh.NewPublicKey(pubKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create ssh public key: %w", err)
|
||||
}
|
||||
pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
|
||||
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
||||
if err := os.WriteFile(pubKeyPath, pubKeyBytes, 0o644); err != nil {
|
||||
return fmt.Errorf("failed to write public key: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("keypair generated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// userResponse matches the API response from ollama.com/api/me
|
||||
type userResponse struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Plan string `json:"plan"`
|
||||
AvatarURL string `json:"avatarurl"`
|
||||
}
|
||||
|
||||
// fetchUserFromAPI fetches user data from ollama.com using signed request
|
||||
func fetchUserFromAPI() (*userResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
signString := fmt.Sprintf("POST,/api/me?ts=%s", timestamp)
|
||||
signature, err := auth.Sign(ctx, []byte(signString))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("https://ollama.com/api/me?ts=%s", timestamp)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signature))
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call ollama.com: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var user userResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
// Make avatar URL absolute
|
||||
if user.AvatarURL != "" && !strings.HasPrefix(user.AvatarURL, "http") {
|
||||
user.AvatarURL = "https://ollama.com/" + user.AvatarURL
|
||||
}
|
||||
|
||||
// Cache the avatar URL
|
||||
cachedAvatarURL = user.AvatarURL
|
||||
|
||||
// Cache the user data
|
||||
if settingsStore != nil {
|
||||
storeUser := store.User{
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Plan: user.Plan,
|
||||
}
|
||||
if err := settingsStore.SetUser(storeUser); err != nil {
|
||||
slog.Warn("failed to cache user", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
//export getAccountName
|
||||
func getAccountName() *C.char {
|
||||
// Only return cached data - never block on network
|
||||
if settingsStore == nil {
|
||||
return C.CString("")
|
||||
}
|
||||
user, err := settingsStore.User()
|
||||
if err != nil || user == nil {
|
||||
return C.CString("")
|
||||
}
|
||||
return C.CString(user.Name)
|
||||
}
|
||||
|
||||
// cachedAvatarURL stores the avatar URL from the last API fetch
|
||||
var cachedAvatarURL string
|
||||
|
||||
//export getAccountAvatarURL
|
||||
func getAccountAvatarURL() *C.char {
|
||||
return C.CString(cachedAvatarURL)
|
||||
}
|
||||
|
||||
//export getAccountEmail
|
||||
func getAccountEmail() *C.char {
|
||||
if settingsStore != nil {
|
||||
user, err := settingsStore.User()
|
||||
if err == nil && user != nil {
|
||||
return C.CString(user.Email)
|
||||
}
|
||||
}
|
||||
return C.CString("")
|
||||
}
|
||||
|
||||
//export getAccountPlan
|
||||
func getAccountPlan() *C.char {
|
||||
if settingsStore != nil {
|
||||
user, err := settingsStore.User()
|
||||
if err == nil && user != nil {
|
||||
return C.CString(user.Plan)
|
||||
}
|
||||
}
|
||||
return C.CString("")
|
||||
}
|
||||
|
||||
//export signOutAccount
|
||||
func signOutAccount() {
|
||||
if settingsStore != nil {
|
||||
if err := settingsStore.ClearUser(); err != nil {
|
||||
slog.Error("failed to clear user", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Also remove the key file
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
slog.Error("failed to get home dir", "error", err)
|
||||
return
|
||||
}
|
||||
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||
if err := os.Remove(keyPath); err != nil && !os.IsNotExist(err) {
|
||||
slog.Error("failed to remove key file", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
//export openConnectUrl
|
||||
func openConnectUrl() {
|
||||
// Ensure keypair exists (generate if needed)
|
||||
if err := ensureKeypair(); err != nil {
|
||||
slog.Error("failed to ensure keypair", "error", err)
|
||||
// Fallback to basic connect page
|
||||
cmd := exec.Command("open", "https://ollama.com/connect")
|
||||
cmd.Start()
|
||||
return
|
||||
}
|
||||
|
||||
// Build connect URL with public key
|
||||
connectURL, err := appauth.BuildConnectURL("https://ollama.com")
|
||||
if err != nil {
|
||||
slog.Error("failed to build connect URL", "error", err)
|
||||
// Fallback to basic connect page
|
||||
connectURL = "https://ollama.com/connect"
|
||||
}
|
||||
|
||||
cmd := exec.Command("open", connectURL)
|
||||
if err := cmd.Start(); err != nil {
|
||||
slog.Error("failed to open connect URL", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
//export refreshAccountFromAPI
|
||||
func refreshAccountFromAPI() {
|
||||
if !hasOllamaKey() {
|
||||
return
|
||||
}
|
||||
_, err := fetchUserFromAPI()
|
||||
if err != nil {
|
||||
slog.Debug("failed to refresh account", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
//export prefetchAccountData
|
||||
func prefetchAccountData() {
|
||||
// Run in background goroutine to not block app startup
|
||||
go func() {
|
||||
if !hasOllamaKey() {
|
||||
return
|
||||
}
|
||||
_, err := fetchUserFromAPI()
|
||||
if err != nil {
|
||||
slog.Debug("failed to prefetch account data", "error", err)
|
||||
} else {
|
||||
slog.Debug("prefetched account data successfully")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// OpenNativeSettings opens the native settings window
|
||||
func OpenNativeSettings() {
|
||||
C.openNativeSettings()
|
||||
}
|
||||
|
||||
// Ensure the CString is freed (caller must free)
|
||||
func freeCString(s *C.char) {
|
||||
C.free(unsafe.Pointer(s))
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
#import <Cocoa/Cocoa.h>
|
||||
|
||||
@interface SettingsWindowController : NSWindowController <NSWindowDelegate>
|
||||
|
||||
// General tab
|
||||
@property(nonatomic, strong) NSButton *exposeCheckbox;
|
||||
@property(nonatomic, strong) NSButton *browserCheckbox;
|
||||
@property(nonatomic, strong) NSSlider *contextLengthSlider;
|
||||
|
||||
// Models tab
|
||||
@property(nonatomic, strong) NSPathControl *modelsPathControl;
|
||||
@property(nonatomic, strong) NSButton *modelsPathButton;
|
||||
|
||||
// Account tab
|
||||
@property(nonatomic, strong) NSView *avatarView;
|
||||
@property(nonatomic, strong) NSTextField *avatarInitialLabel;
|
||||
@property(nonatomic, strong) NSImageView *avatarImageView;
|
||||
@property(nonatomic, strong) NSTextField *accountNameLabel;
|
||||
@property(nonatomic, strong) NSTextField *accountEmailLabel;
|
||||
@property(nonatomic, strong) NSButton *manageButton;
|
||||
@property(nonatomic, strong) NSButton *signOutButton;
|
||||
@property(nonatomic, strong) NSButton *signInButton;
|
||||
@property(nonatomic, strong) NSView *signedInContainer;
|
||||
@property(nonatomic, strong) NSView *signedOutContainer;
|
||||
|
||||
// Plan section
|
||||
@property(nonatomic, strong) NSView *planContainer;
|
||||
@property(nonatomic, strong) NSTextField *planNameLabel;
|
||||
@property(nonatomic, strong) NSButton *upgradeButton;
|
||||
@property(nonatomic, strong) NSButton *viewUsageButton;
|
||||
|
||||
+ (instancetype)sharedController;
|
||||
- (void)showSettings;
|
||||
|
||||
@end
|
||||
|
||||
// Go callbacks for settings
|
||||
void openNativeSettings(void);
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,16 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/app/store"
|
||||
|
||||
// SetSettingsStore sets the store reference for settings callbacks (stub for Windows)
|
||||
func SetSettingsStore(s *store.Store) {
|
||||
// TODO: Implement Windows native settings
|
||||
}
|
||||
|
||||
// SetRestartCallback sets the function to call when settings change requires a restart (stub for Windows)
|
||||
func SetRestartCallback(cb func()) {
|
||||
// TODO: Implement Windows native settings
|
||||
}
|
||||
|
||||
@@ -9,12 +9,12 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// currentSchemaVersion defines the current database schema version.
|
||||
// Increment this when making schema changes that require migrations.
|
||||
const currentSchemaVersion = 12
|
||||
const currentSchemaVersion = 13
|
||||
|
||||
// database wraps the SQLite connection.
|
||||
// SQLite handles its own locking for concurrent access:
|
||||
@@ -85,6 +85,7 @@ func (db *database) init() error {
|
||||
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
think_level TEXT NOT NULL DEFAULT '',
|
||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
|
||||
schema_version INTEGER NOT NULL DEFAULT %d
|
||||
);
|
||||
|
||||
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
|
||||
return fmt.Errorf("migrate v11 to v12: %w", err)
|
||||
}
|
||||
version = 12
|
||||
case 12:
|
||||
// add auto_update_enabled column to settings table
|
||||
if err := db.migrateV12ToV13(); err != nil {
|
||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
||||
}
|
||||
version = 13
|
||||
default:
|
||||
// If we have a version we don't recognize, just set it to current
|
||||
// This might happen during development
|
||||
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateV12ToV13 adds the auto_update_enabled column to the settings table
|
||||
func (db *database) migrateV12ToV13() error {
|
||||
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
|
||||
if err != nil && !duplicateColumnError(err) {
|
||||
return fmt.Errorf("add auto_update_enabled column: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update schema version: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||
func (db *database) cleanupOrphanedData() error {
|
||||
_, err := db.conn.Exec(`
|
||||
@@ -482,19 +504,11 @@ func (db *database) cleanupOrphanedData() error {
|
||||
}
|
||||
|
||||
func duplicateColumnError(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
||||
}
|
||||
return false
|
||||
return err != nil && strings.Contains(err.Error(), "duplicate column name")
|
||||
}
|
||||
|
||||
func columnNotExists(err error) bool {
|
||||
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||
strings.Contains(sqlite3Err.Error(), "no such column")
|
||||
}
|
||||
return false
|
||||
return err != nil && strings.Contains(err.Error(), "no such column")
|
||||
}
|
||||
|
||||
func (db *database) getAllChats() ([]Chat, error) {
|
||||
@@ -1108,9 +1122,9 @@ func (db *database) getSettings() (Settings, error) {
|
||||
var s Settings
|
||||
|
||||
err := db.conn.QueryRow(`
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
||||
FROM settings
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
||||
if err != nil {
|
||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||
}
|
||||
@@ -1121,8 +1135,8 @@ func (db *database) getSettings() (Settings, error) {
|
||||
func (db *database) setSettings(s Settings) error {
|
||||
_, err := db.conn.Exec(`
|
||||
UPDATE settings
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set settings: %w", err)
|
||||
}
|
||||
|
||||
@@ -169,6 +169,9 @@ type Settings struct {
|
||||
|
||||
// SidebarOpen indicates if the chat sidebar is open
|
||||
SidebarOpen bool
|
||||
|
||||
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
||||
AutoUpdateEnabled bool
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
|
||||
@@ -413,6 +413,7 @@ export class Settings {
|
||||
ThinkLevel: string;
|
||||
SelectedModel: string;
|
||||
SidebarOpen: boolean;
|
||||
AutoUpdateEnabled: boolean;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
@@ -431,6 +432,7 @@ export class Settings {
|
||||
this.ThinkLevel = source["ThinkLevel"];
|
||||
this.SelectedModel = source["SelectedModel"];
|
||||
this.SidebarOpen = source["SidebarOpen"];
|
||||
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
||||
}
|
||||
}
|
||||
export class SettingsResponse {
|
||||
@@ -467,6 +469,46 @@ export class HealthResponse {
|
||||
this.healthy = source["healthy"];
|
||||
}
|
||||
}
|
||||
export class UpdateInfo {
|
||||
currentVersion: string;
|
||||
availableVersion: string;
|
||||
updateAvailable: boolean;
|
||||
updateDownloaded: boolean;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.currentVersion = source["currentVersion"];
|
||||
this.availableVersion = source["availableVersion"];
|
||||
this.updateAvailable = source["updateAvailable"];
|
||||
this.updateDownloaded = source["updateDownloaded"];
|
||||
}
|
||||
}
|
||||
export class UpdateCheckResponse {
|
||||
updateInfo: UpdateInfo;
|
||||
|
||||
constructor(source: any = {}) {
|
||||
if ('string' === typeof source) source = JSON.parse(source);
|
||||
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
|
||||
}
|
||||
|
||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
||||
if (!a) {
|
||||
return a;
|
||||
}
|
||||
if (Array.isArray(a)) {
|
||||
return (a as any[]).map(elem => this.convertValues(elem, classs));
|
||||
} else if ("object" === typeof a) {
|
||||
if (asMap) {
|
||||
for (const key of Object.keys(a)) {
|
||||
a[key] = new classs(a[key]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
return new classs(a);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
}
|
||||
export class User {
|
||||
id: string;
|
||||
email: string;
|
||||
|
||||
@@ -414,3 +414,54 @@ export async function fetchHealth(): Promise<boolean> {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCurrentVersion(): Promise<string> {
|
||||
try {
|
||||
const response = await fetch(`${API_BASE}/api/version`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
return data.version || "Unknown";
|
||||
}
|
||||
return "Unknown";
|
||||
} catch (error) {
|
||||
console.error("Error fetching version:", error);
|
||||
return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkForUpdate(): Promise<{
|
||||
currentVersion: string;
|
||||
availableVersion: string;
|
||||
updateAvailable: boolean;
|
||||
updateDownloaded: boolean;
|
||||
}> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/update/check`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to check for update");
|
||||
}
|
||||
const data = await response.json();
|
||||
return data.updateInfo;
|
||||
}
|
||||
|
||||
export async function installUpdate(): Promise<void> {
|
||||
const response = await fetch(`${API_BASE}/api/v1/update/install`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(error || "Failed to install update");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,13 @@ import {
|
||||
XMarkIcon,
|
||||
CogIcon,
|
||||
ArrowLeftIcon,
|
||||
ArrowDownTrayIcon,
|
||||
} from "@heroicons/react/20/solid";
|
||||
import { Settings as SettingsType } from "@/gotypes";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
import { useUser } from "@/hooks/useUser";
|
||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { getSettings, updateSettings } from "@/api";
|
||||
import { getSettings, updateSettings, checkForUpdate } from "@/api";
|
||||
|
||||
function AnimatedDots() {
|
||||
return (
|
||||
@@ -39,6 +40,12 @@ export default function Settings() {
|
||||
const queryClient = useQueryClient();
|
||||
const [showSaved, setShowSaved] = useState(false);
|
||||
const [restartMessage, setRestartMessage] = useState(false);
|
||||
const [updateInfo, setUpdateInfo] = useState<{
|
||||
currentVersion: string;
|
||||
availableVersion: string;
|
||||
updateAvailable: boolean;
|
||||
updateDownloaded: boolean;
|
||||
} | null>(null);
|
||||
const {
|
||||
user,
|
||||
isAuthenticated,
|
||||
@@ -76,8 +83,22 @@ export default function Settings() {
|
||||
|
||||
useEffect(() => {
|
||||
refetchUser();
|
||||
// Check for updates
|
||||
checkForUpdate()
|
||||
.then(setUpdateInfo)
|
||||
.catch((err) => console.error("Error checking for update:", err));
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// Refresh update info when auto-update toggle changes
|
||||
useEffect(() => {
|
||||
if (settings?.AutoUpdateEnabled !== undefined) {
|
||||
checkForUpdate()
|
||||
.then(setUpdateInfo)
|
||||
.catch((err) => console.error("Error checking for update:", err));
|
||||
}
|
||||
}, [settings?.AutoUpdateEnabled]);
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
const handleFocus = () => {
|
||||
if (isAwaitingConnection && pollingInterval) {
|
||||
@@ -344,6 +365,58 @@ export default function Settings() {
|
||||
{/* Local Configuration */}
|
||||
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||
<div className="space-y-4 p-4">
|
||||
{/* Auto Update */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="flex items-start space-x-3 flex-1">
|
||||
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
||||
<div className="flex-1">
|
||||
<Label>Auto-download updates</Label>
|
||||
<Description>
|
||||
{settings.AutoUpdateEnabled ? (
|
||||
<>
|
||||
Automatically downloads updates when available.
|
||||
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
|
||||
Current version: {updateInfo?.currentVersion || "Loading..."}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
Manually download updates.
|
||||
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
|
||||
<div className="space-y-2 text-sm">
|
||||
<div className="flex justify-between">
|
||||
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
|
||||
</div>
|
||||
{updateInfo?.availableVersion && (
|
||||
<div className="flex justify-between">
|
||||
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<a
|
||||
href="https://ollama.com/download"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
|
||||
>
|
||||
Download new version →
|
||||
</a>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Description>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex-shrink-0">
|
||||
<Switch
|
||||
checked={settings.AutoUpdateEnabled}
|
||||
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
{/* Expose Ollama */}
|
||||
<Field>
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
|
||||
@@ -147,6 +147,7 @@ export const highlighterPromise = createHighlighter({
|
||||
"c",
|
||||
"cpp",
|
||||
"sql",
|
||||
"swift",
|
||||
"yaml",
|
||||
"markdown",
|
||||
],
|
||||
|
||||
@@ -100,6 +100,17 @@ type HealthResponse struct {
|
||||
Healthy bool `json:"healthy"`
|
||||
}
|
||||
|
||||
type UpdateInfo struct {
|
||||
CurrentVersion string `json:"currentVersion"`
|
||||
AvailableVersion string `json:"availableVersion"`
|
||||
UpdateAvailable bool `json:"updateAvailable"`
|
||||
UpdateDownloaded bool `json:"updateDownloaded"`
|
||||
}
|
||||
|
||||
type UpdateCheckResponse struct {
|
||||
UpdateInfo UpdateInfo `json:"updateInfo"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
|
||||
100
app/ui/ui.go
100
app/ui/ui.go
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/ollama/ollama/app/tools"
|
||||
"github.com/ollama/ollama/app/types/not"
|
||||
"github.com/ollama/ollama/app/ui/responses"
|
||||
"github.com/ollama/ollama/app/updater"
|
||||
"github.com/ollama/ollama/app/version"
|
||||
ollamaAuth "github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -106,6 +107,18 @@ type Server struct {
|
||||
|
||||
// Dev is true if the server is running in development mode
|
||||
Dev bool
|
||||
|
||||
// Updater for checking and downloading updates
|
||||
Updater UpdaterInterface
|
||||
UpdateAvailableFunc func()
|
||||
}
|
||||
|
||||
// UpdaterInterface defines the methods we need from the updater
|
||||
type UpdaterInterface interface {
|
||||
CheckForUpdate(ctx context.Context) (bool, string, error)
|
||||
InstallAndRestart() error
|
||||
CancelOngoingDownload()
|
||||
TriggerImmediateCheck()
|
||||
}
|
||||
|
||||
func (s *Server) log() *slog.Logger {
|
||||
@@ -284,6 +297,8 @@ func (s *Server) Handler() http.Handler {
|
||||
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
||||
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
||||
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
||||
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
|
||||
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
|
||||
|
||||
// Ollama proxy endpoints
|
||||
ollamaProxy := s.ollamaProxy()
|
||||
@@ -1448,6 +1463,24 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||
return fmt.Errorf("failed to save settings: %w", err)
|
||||
}
|
||||
|
||||
// Handle auto-update toggle changes
|
||||
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
|
||||
if !settings.AutoUpdateEnabled {
|
||||
// Auto-update disabled: cancel any ongoing download
|
||||
if s.Updater != nil {
|
||||
s.Updater.CancelOngoingDownload()
|
||||
}
|
||||
} else {
|
||||
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
|
||||
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
|
||||
s.UpdateAvailableFunc()
|
||||
} else if s.Updater != nil {
|
||||
// Trigger the background checker to run immediately
|
||||
s.Updater.TriggerImmediateCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if old.ContextLength != settings.ContextLength ||
|
||||
old.Models != settings.Models ||
|
||||
old.Expose != settings.Expose {
|
||||
@@ -1524,6 +1557,73 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
||||
return json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
|
||||
currentVersion := version.Version
|
||||
|
||||
if s.Updater == nil {
|
||||
return fmt.Errorf("updater not available")
|
||||
}
|
||||
|
||||
updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
|
||||
if err != nil {
|
||||
s.log().Warn("failed to check for update", "error", err)
|
||||
// Don't return error, just log it and continue with no update available
|
||||
}
|
||||
|
||||
response := responses.UpdateCheckResponse{
|
||||
UpdateInfo: responses.UpdateInfo{
|
||||
CurrentVersion: currentVersion,
|
||||
AvailableVersion: updateVersion,
|
||||
UpdateAvailable: updateAvailable,
|
||||
UpdateDownloaded: updater.UpdateDownloaded,
|
||||
},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return fmt.Errorf("method not allowed")
|
||||
}
|
||||
|
||||
if s.Updater == nil {
|
||||
s.log().Error("install failed: updater not available")
|
||||
return fmt.Errorf("updater not available")
|
||||
}
|
||||
|
||||
// Check if update is downloaded
|
||||
if !updater.UpdateDownloaded {
|
||||
s.log().Error("install failed: no update downloaded")
|
||||
return fmt.Errorf("no update downloaded")
|
||||
}
|
||||
|
||||
// Send response before restarting
|
||||
response := map[string]any{
|
||||
"success": true,
|
||||
"message": "Installing update and restarting...",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Give the response time to be sent
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger the upgrade and restart
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := s.Updater.InstallAndRestart(); err != nil {
|
||||
s.log().Error("failed to install update", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func userAgent() string {
|
||||
buildinfo, _ := debug.ReadBuildInfo()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/app/store"
|
||||
@@ -58,7 +59,8 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||
query := requestURL.Query()
|
||||
query.Add("os", runtime.GOOS)
|
||||
query.Add("arch", runtime.GOARCH)
|
||||
query.Add("version", version.Version)
|
||||
currentVersion := version.Version
|
||||
query.Add("version", currentVersion)
|
||||
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
// The original macOS app used to use the device ID
|
||||
@@ -131,15 +133,27 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||
}
|
||||
|
||||
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||
// Create a cancellable context for this download
|
||||
downloadCtx, cancel := context.WithCancel(ctx)
|
||||
u.cancelDownloadLock.Lock()
|
||||
u.cancelDownload = cancel
|
||||
u.cancelDownloadLock.Unlock()
|
||||
defer func() {
|
||||
u.cancelDownloadLock.Lock()
|
||||
u.cancelDownload = nil
|
||||
u.cancelDownloadLock.Unlock()
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Do a head first to check etag info
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// In case of slow downloads, continue the update check in the background
|
||||
bgctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
bgctx, bgcancel := context.WithCancel(downloadCtx)
|
||||
defer bgcancel()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
@@ -176,6 +190,7 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
||||
_, err = os.Stat(stageFilename)
|
||||
if err == nil {
|
||||
slog.Info("update already downloaded", "bundle", stageFilename)
|
||||
UpdateDownloaded = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -244,34 +259,95 @@ func cleanupOldDownloads(stageDir string) {
|
||||
}
|
||||
|
||||
type Updater struct {
|
||||
Store *store.Store
|
||||
Store *store.Store
|
||||
cancelDownload context.CancelFunc
|
||||
cancelDownloadLock sync.Mutex
|
||||
checkNow chan struct{}
|
||||
}
|
||||
|
||||
// CancelOngoingDownload cancels any currently running download
|
||||
func (u *Updater) CancelOngoingDownload() {
|
||||
u.cancelDownloadLock.Lock()
|
||||
defer u.cancelDownloadLock.Unlock()
|
||||
if u.cancelDownload != nil {
|
||||
slog.Info("cancelling ongoing update download")
|
||||
u.cancelDownload()
|
||||
u.cancelDownload = nil
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerImmediateCheck signals the background checker to check for updates immediately
|
||||
func (u *Updater) TriggerImmediateCheck() {
|
||||
if u.checkNow != nil {
|
||||
u.checkNow <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||
u.checkNow = make(chan struct{}, 1)
|
||||
go func() {
|
||||
// Don't blast an update message immediately after startup
|
||||
time.Sleep(UpdateCheckInitialDelay)
|
||||
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
||||
ticker := time.NewTicker(UpdateCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
if available {
|
||||
err := u.DownloadNewRelease(ctx, resp)
|
||||
if err != nil {
|
||||
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
|
||||
} else {
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug("stopping background update checker")
|
||||
return
|
||||
default:
|
||||
time.Sleep(UpdateCheckInterval)
|
||||
case <-u.checkNow:
|
||||
// Immediate check triggered
|
||||
case <-ticker.C:
|
||||
// Regular interval check
|
||||
}
|
||||
|
||||
// Always check for updates
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
if !available {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update is available - check if auto-update is enabled for downloading
|
||||
settings, err := u.Store.Settings()
|
||||
if err != nil {
|
||||
slog.Error("failed to load settings", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !settings.AutoUpdateEnabled {
|
||||
// Auto-update disabled - don't download, just log
|
||||
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
|
||||
continue
|
||||
}
|
||||
|
||||
// Auto-update is enabled - download
|
||||
err = u.DownloadNewRelease(ctx, resp)
|
||||
if err != nil {
|
||||
slog.Error("failed to download new release", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Download successful - show tray notification (regardless of toggle state)
|
||||
err = cb(resp.UpdateVersion)
|
||||
if err != nil {
|
||||
slog.Warn("failed to register update available with tray", "error", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
|
||||
available, resp := u.checkForUpdate(ctx)
|
||||
return available, resp.UpdateVersion, nil
|
||||
}
|
||||
|
||||
func (u *Updater) InstallAndRestart() error {
|
||||
if !UpdateDownloaded {
|
||||
return fmt.Errorf("no update downloaded")
|
||||
}
|
||||
|
||||
slog.Info("installing update and restarting")
|
||||
return DoUpgrade(true)
|
||||
}
|
||||
|
||||
@@ -85,7 +85,17 @@ func TestBackgoundChecker(t *testing.T) {
|
||||
UpdateCheckURLBase = server.URL + "/update.json"
|
||||
|
||||
updater := &Updater{Store: &store.Store{}}
|
||||
defer updater.Store.Close() // Ensure database is closed
|
||||
defer updater.Store.Close()
|
||||
|
||||
settings, err := updater.Store.Settings()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
settings.AutoUpdateEnabled = true
|
||||
if err := updater.Store.SetSettings(settings); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
|
||||
@@ -369,24 +369,24 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
|
||||
// const ERROR_SUCCESS syscall.Errno = 0
|
||||
func (t *winTray) removeMenuItem(menuItemId, parentId uint32) error {
|
||||
const ERROR_SUCCESS syscall.Errno = 0
|
||||
|
||||
// t.muMenus.RLock()
|
||||
// menu := uintptr(t.menus[parentId])
|
||||
// t.muMenus.RUnlock()
|
||||
// res, _, err := pRemoveMenu.Call(
|
||||
// menu,
|
||||
// uintptr(menuItemId),
|
||||
// MF_BYCOMMAND,
|
||||
// )
|
||||
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
||||
// return err
|
||||
// }
|
||||
// t.delFromVisibleItems(parentId, menuItemId)
|
||||
t.muMenus.RLock()
|
||||
menu := uintptr(t.menus[parentId])
|
||||
t.muMenus.RUnlock()
|
||||
res, _, err := pRemoveMenu.Call(
|
||||
menu,
|
||||
uintptr(menuItemId),
|
||||
MF_BYCOMMAND,
|
||||
)
|
||||
if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
||||
return err
|
||||
}
|
||||
t.delFromVisibleItems(parentId, menuItemId)
|
||||
|
||||
// return nil
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *winTray) showMenu() error {
|
||||
p := point{}
|
||||
|
||||
@@ -30,6 +30,7 @@ var (
|
||||
pPostQuitMessage = u32.NewProc("PostQuitMessage")
|
||||
pRegisterClass = u32.NewProc("RegisterClassExW")
|
||||
pRegisterWindowMessage = u32.NewProc("RegisterWindowMessageW")
|
||||
pRemoveMenu = u32.NewProc("RemoveMenu")
|
||||
pSendMessage = u32.NewProc("SendMessageW")
|
||||
pSetForegroundWindow = u32.NewProc("SetForegroundWindow")
|
||||
pSetMenuInfo = u32.NewProc("SetMenuInfo")
|
||||
|
||||
@@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_temperature",
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Toronto"
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "11 degrees celsius",
|
||||
"tool_name": "get_temperature",
|
||||
"tool_name": "get_weather"
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
|
||||
@@ -277,6 +277,8 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
||||
|
||||
### `/v1/responses`
|
||||
|
||||
> Note: Added in Ollama v0.13.3
|
||||
|
||||
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
|
||||
|
||||
#### Supported features
|
||||
|
||||
@@ -36,7 +36,6 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
|
||||
}],
|
||||
"stream": false
|
||||
}'
|
||||
"
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Python">
|
||||
|
||||
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
## How can I view the logs?
|
||||
|
||||
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
|
||||
Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
|
||||
|
||||
## Is my GPU compatible with Ollama?
|
||||
|
||||
Please refer to the [GPU docs](./gpu.md).
|
||||
Please refer to the [GPU docs](./gpu).
|
||||
|
||||
## How can I specify the context window size?
|
||||
|
||||
|
||||
10
docs/gpu.mdx
10
docs/gpu.mdx
@@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
|
||||
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
|
||||
|
||||
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
|
||||
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia)
|
||||
|
||||
### GPU Selection
|
||||
|
||||
@@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
|
||||
|
||||
Ollama supports the following AMD GPUs via the ROCm library:
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Additional AMD GPU support is provided by the Vulkan Library - see below.
|
||||
|
||||
|
||||
@@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
|
||||
|
||||
## Vulkan GPU Support
|
||||
|
||||
> [!NOTE]
|
||||
> **NOTE:**
|
||||
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server)
|
||||
|
||||
Additional GPU support on Windows and Linux is provided via
|
||||
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
|
||||
@@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
|
||||
|
||||
To select specific Vulkan GPU(s), you can set the environment variable
|
||||
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
|
||||
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
|
||||
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
|
||||
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
|
||||
by setting `GGML_VK_VISIBLE_DEVICES=-1`
|
||||
@@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
|
||||
|
||||
### Linux NVIDIA Troubleshooting
|
||||
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
|
||||
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker)
|
||||
|
||||
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
|
||||
|
||||
|
||||
@@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 3 +
|
||||
ggml/src/ggml-impl.h | 8 +
|
||||
ggml/src/ggml-metal/ggml-metal.cpp | 2 +
|
||||
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
|
||||
ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
|
||||
ggml/src/mem_nvml.cpp | 209 +++++++++++
|
||||
9 files changed, 976 insertions(+), 17 deletions(-)
|
||||
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++-
|
||||
ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++
|
||||
ggml/src/mem_nvml.cpp | 209 ++++++++++
|
||||
9 files changed, 1005 insertions(+), 17 deletions(-)
|
||||
create mode 100644 ggml/src/mem_hip.cpp
|
||||
create mode 100644 ggml/src/mem_nvml.cpp
|
||||
|
||||
@@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644
|
||||
|
||||
set_target_properties(ggml-base PROPERTIES
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 6852d2e20..48cdb1dcf 100644
|
||||
index 6852d2e20..334a30135 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
@@ -109,7 +109,7 @@ index 6852d2e20..48cdb1dcf 100644
|
||||
+
|
||||
+#if defined(GGML_USE_HIP)
|
||||
+ if (ggml_hip_mgmt_init() == 0) {
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
|
||||
+ if (status == 0) {
|
||||
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
|
||||
+ ggml_hip_mgmt_release();
|
||||
@@ -204,7 +204,7 @@ index 4e162258d..d89e35a8e 100644
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||
index fe57d4c58..1c07e767a 100644
|
||||
index fe57d4c58..dba8f4695 100644
|
||||
--- a/ggml/src/ggml-impl.h
|
||||
+++ b/ggml/src/ggml-impl.h
|
||||
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
@@ -216,7 +216,7 @@ index fe57d4c58..1c07e767a 100644
|
||||
+GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
|
||||
+GGML_API void ggml_nvml_release();
|
||||
+GGML_API int ggml_hip_mgmt_init();
|
||||
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
+GGML_API void ggml_hip_mgmt_release();
|
||||
+
|
||||
#ifdef __cplusplus
|
||||
@@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644
|
||||
/* .async = */ true,
|
||||
/* .host_buffer = */ false,
|
||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
index 5349bce24..d43d46d1d 100644
|
||||
index 5349bce24..0103fd03a 100644
|
||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
@@ -236,6 +236,7 @@ class vk_memory_logger;
|
||||
@@ -334,7 +334,7 @@ index 5349bce24..d43d46d1d 100644
|
||||
+ switch (props2.properties.vendorID) {
|
||||
+ case VK_VENDOR_ID_AMD:
|
||||
+ if (ggml_hip_mgmt_init() == 0) {
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
|
||||
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
|
||||
+ if (status == 0) {
|
||||
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
|
||||
+ ggml_hip_mgmt_release();
|
||||
@@ -505,10 +505,10 @@ index 5349bce24..d43d46d1d 100644
|
||||
}
|
||||
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
|
||||
new file mode 100644
|
||||
index 000000000..c1949b899
|
||||
index 000000000..23c765806
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/mem_hip.cpp
|
||||
@@ -0,0 +1,529 @@
|
||||
@@ -0,0 +1,558 @@
|
||||
+#include "ggml.h"
|
||||
+#include "ggml-impl.h"
|
||||
+
|
||||
@@ -842,7 +842,7 @@ index 000000000..c1949b899
|
||||
+ if (gpus != NULL) gpus->pVtbl->Release(gpus); \
|
||||
+ if (gpu != NULL) gpu->pVtbl->Release(gpu)
|
||||
+
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
+ std::lock_guard<std::mutex> lock(ggml_adlx_lock);
|
||||
+ if (adlx.handle == NULL) {
|
||||
+ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
|
||||
@@ -966,13 +966,16 @@ index 000000000..c1949b899
|
||||
+ return 0;
|
||||
+}
|
||||
+void ggml_hip_mgmt_release() {}
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
|
||||
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
|
||||
+ const std::string drmTotalMemoryFile = "mem_info_vram_total";
|
||||
+ const std::string drmUsedMemoryFile = "mem_info_vram_used";
|
||||
+ const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
|
||||
+ const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
|
||||
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
|
||||
+
|
||||
+
|
||||
+ glob_t glob_result;
|
||||
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
|
||||
+
|
||||
@@ -1006,7 +1009,6 @@ index 000000000..c1949b899
|
||||
+
|
||||
+ uint64_t memory;
|
||||
+ totalFileStream >> memory;
|
||||
+ *total = memory;
|
||||
+
|
||||
+ std::string usedFile = dir + "/" + drmUsedMemoryFile;
|
||||
+ std::ifstream usedFileStream(usedFile.c_str());
|
||||
@@ -1019,6 +1021,33 @@ index 000000000..c1949b899
|
||||
+
|
||||
+ uint64_t memoryUsed;
|
||||
+ usedFileStream >> memoryUsed;
|
||||
+
|
||||
+ if (is_integrated_gpu) {
|
||||
+ std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
|
||||
+ std::ifstream totalFileStream(totalFile.c_str());
|
||||
+ if (!totalFileStream.is_open()) {
|
||||
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
|
||||
+ file.close();
|
||||
+ globfree(&glob_result);
|
||||
+ return 1;
|
||||
+ }
|
||||
+ uint64_t gtt;
|
||||
+ totalFileStream >> gtt;
|
||||
+ std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
|
||||
+ std::ifstream usedFileStream(usedFile.c_str());
|
||||
+ if (!usedFileStream.is_open()) {
|
||||
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
|
||||
+ file.close();
|
||||
+ globfree(&glob_result);
|
||||
+ return 1;
|
||||
+ }
|
||||
+ uint64_t gttUsed;
|
||||
+ usedFileStream >> gttUsed;
|
||||
+ memory += gtt;
|
||||
+ memoryUsed += gttUsed;
|
||||
+ }
|
||||
+
|
||||
+ *total = memory;
|
||||
+ *free = memory - memoryUsed;
|
||||
+
|
||||
+ file.close();
|
||||
|
||||
@@ -24,12 +24,12 @@ index 99ae293cc..9a134b7af 100644
|
||||
|
||||
set_target_properties(ggml-base PROPERTIES
|
||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||
index 1c07e767a..0da3e065b 100644
|
||||
index dba8f4695..7e17032c7 100644
|
||||
--- a/ggml/src/ggml-impl.h
|
||||
+++ b/ggml/src/ggml-impl.h
|
||||
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
|
||||
GGML_API int ggml_hip_mgmt_init();
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
GGML_API void ggml_hip_mgmt_release();
|
||||
+GGML_API int ggml_dxgi_pdh_init();
|
||||
+GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
@@ -38,7 +38,7 @@ index 1c07e767a..0da3e065b 100644
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
index d43d46d1d..df79f9f79 100644
|
||||
index 0103fd03a..9cc4ebdef 100644
|
||||
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
|
||||
@@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
||||
|
||||
@@ -10,7 +10,7 @@ fallback to cpu
|
||||
1 file changed, 3 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 48cdb1dcf..3102d7ea7 100644
|
||||
index 334a30135..5c9dfd032 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
|
||||
@@ -4436,7 +4436,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
if (ggml_hip_mgmt_init() == 0) {
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0);
|
||||
if (status == 0) {
|
||||
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
|
||||
ggml_hip_mgmt_release();
|
||||
|
||||
2
ml/backend/ggml/ggml/src/ggml-impl.h
vendored
2
ml/backend/ggml/ggml/src/ggml-impl.h
vendored
@@ -682,7 +682,7 @@ GGML_API int ggml_nvml_init();
|
||||
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
|
||||
GGML_API void ggml_nvml_release();
|
||||
GGML_API int ggml_hip_mgmt_init();
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
|
||||
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
GGML_API void ggml_hip_mgmt_release();
|
||||
GGML_API int ggml_dxgi_pdh_init();
|
||||
GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
|
||||
|
||||
@@ -13710,7 +13710,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
|
||||
switch (props2.properties.vendorID) {
|
||||
case VK_VENDOR_ID_AMD:
|
||||
if (ggml_hip_mgmt_init() == 0) {
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
|
||||
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu);
|
||||
if (status == 0) {
|
||||
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
|
||||
ggml_hip_mgmt_release();
|
||||
|
||||
35
ml/backend/ggml/ggml/src/mem_hip.cpp
vendored
35
ml/backend/ggml/ggml/src/mem_hip.cpp
vendored
@@ -331,7 +331,7 @@ void ggml_hip_mgmt_release() {
|
||||
if (gpus != NULL) gpus->pVtbl->Release(gpus); \
|
||||
if (gpu != NULL) gpu->pVtbl->Release(gpu)
|
||||
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
std::lock_guard<std::mutex> lock(ggml_adlx_lock);
|
||||
if (adlx.handle == NULL) {
|
||||
GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
|
||||
@@ -455,13 +455,16 @@ int ggml_hip_mgmt_init() {
|
||||
return 0;
|
||||
}
|
||||
void ggml_hip_mgmt_release() {}
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) {
|
||||
GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
|
||||
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
|
||||
const std::string drmTotalMemoryFile = "mem_info_vram_total";
|
||||
const std::string drmUsedMemoryFile = "mem_info_vram_used";
|
||||
const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
|
||||
const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
|
||||
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
|
||||
|
||||
|
||||
glob_t glob_result;
|
||||
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
|
||||
|
||||
@@ -495,7 +498,6 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
|
||||
uint64_t memory;
|
||||
totalFileStream >> memory;
|
||||
*total = memory;
|
||||
|
||||
std::string usedFile = dir + "/" + drmUsedMemoryFile;
|
||||
std::ifstream usedFileStream(usedFile.c_str());
|
||||
@@ -508,6 +510,33 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
|
||||
|
||||
uint64_t memoryUsed;
|
||||
usedFileStream >> memoryUsed;
|
||||
|
||||
if (is_integrated_gpu) {
|
||||
std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
|
||||
std::ifstream totalFileStream(totalFile.c_str());
|
||||
if (!totalFileStream.is_open()) {
|
||||
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
|
||||
file.close();
|
||||
globfree(&glob_result);
|
||||
return 1;
|
||||
}
|
||||
uint64_t gtt;
|
||||
totalFileStream >> gtt;
|
||||
std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
|
||||
std::ifstream usedFileStream(usedFile.c_str());
|
||||
if (!usedFileStream.is_open()) {
|
||||
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
|
||||
file.close();
|
||||
globfree(&glob_result);
|
||||
return 1;
|
||||
}
|
||||
uint64_t gttUsed;
|
||||
usedFileStream >> gttUsed;
|
||||
memory += gtt;
|
||||
memoryUsed += gttUsed;
|
||||
}
|
||||
|
||||
*total = memory;
|
||||
*free = memory - memoryUsed;
|
||||
|
||||
file.close();
|
||||
|
||||
@@ -752,9 +752,15 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return err
|
||||
}
|
||||
// TODO: this first normalization should be done by the model
|
||||
embedding = normalize(embedding)
|
||||
embedding, err = normalize(embedding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
|
||||
embedding = normalize(embedding[:req.Dimensions])
|
||||
embedding, err = normalize(embedding[:req.Dimensions])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
embeddings[i] = embedding
|
||||
atomic.AddUint64(&totalTokens, uint64(tokenCount))
|
||||
@@ -787,9 +793,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func normalize(vec []float32) []float32 {
|
||||
func normalize(vec []float32) ([]float32, error) {
|
||||
var sum float32
|
||||
for _, v := range vec {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
return nil, errors.New("embedding contains NaN or Inf values")
|
||||
}
|
||||
sum += v * v
|
||||
}
|
||||
|
||||
@@ -797,7 +806,7 @@ func normalize(vec []float32) []float32 {
|
||||
for i := range vec {
|
||||
vec[i] *= norm
|
||||
}
|
||||
return vec
|
||||
return vec, nil
|
||||
}
|
||||
|
||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
@@ -2395,4 +2404,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
|
||||
@@ -723,15 +723,20 @@ func TestShow(t *testing.T) {
|
||||
|
||||
func TestNormalize(t *testing.T) {
|
||||
type testCase struct {
|
||||
input []float32
|
||||
input []float32
|
||||
expectError bool
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{input: []float32{1}},
|
||||
{input: []float32{0, 1, 2, 3}},
|
||||
{input: []float32{0.1, 0.2, 0.3}},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float32{0, 0, 0}},
|
||||
{input: []float32{1}, expectError: false},
|
||||
{input: []float32{0, 1, 2, 3}, expectError: false},
|
||||
{input: []float32{0.1, 0.2, 0.3}, expectError: false},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}, expectError: false},
|
||||
{input: []float32{0, 0, 0}, expectError: false},
|
||||
{input: []float32{float32(math.NaN()), 0.2, 0.3}, expectError: true},
|
||||
{input: []float32{0.1, float32(math.NaN()), 0.3}, expectError: true},
|
||||
{input: []float32{float32(math.Inf(1)), 0.2, 0.3}, expectError: true},
|
||||
{input: []float32{float32(math.Inf(-1)), 0.2, 0.3}, expectError: true},
|
||||
}
|
||||
|
||||
isNormalized := func(vec []float32) (res bool) {
|
||||
@@ -748,9 +753,18 @@ func TestNormalize(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
normalized := normalize(tc.input)
|
||||
if !isNormalized(normalized) {
|
||||
t.Errorf("Vector %v is not normalized", tc.input)
|
||||
normalized, err := normalize(tc.input)
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for input %v, but got none", tc.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for input %v: %v", tc.input, err)
|
||||
}
|
||||
if !isNormalized(normalized) {
|
||||
t.Errorf("Vector %v is not normalized", tc.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user