Compare commits

..

3 Commits

Author SHA1 Message Date
Eva Ho
42d6a3f075 proper clear draft message 2025-12-12 16:56:44 -05:00
Eva Ho
ed553f51f7 clean up 2025-12-12 16:11:03 -05:00
Eva Ho
7d6f0c621f adding draft for each chat to remember unsent prompts 2025-12-12 15:59:47 -05:00
29 changed files with 488 additions and 1348 deletions

View File

@@ -305,6 +305,9 @@ func main() {
go func() {
<-signals
slog.Info("received SIGINT or SIGTERM signal, shutting down")
if err := st.ClearAllDrafts(); err != nil {
slog.Warn("failed to clear drafts on shutdown", "error", err)
}
quit()
}()

View File

@@ -182,6 +182,11 @@ func osRun(_ func(), hasCompletedFirstRun, startHidden bool) {
}
func quit() {
if wv.Store != nil {
if err := wv.Store.ClearAllDrafts(); err != nil {
slog.Warn("failed to clear drafts on quit", "error", err)
}
}
C.quit()
}

View File

@@ -111,6 +111,11 @@ func (*appCallbacks) UIRunning() bool {
}
func (app *appCallbacks) Quit() {
if wv.Store != nil {
if err := wv.Store.ClearAllDrafts(); err != nil {
slog.Warn("failed to clear drafts on quit", "error", err)
}
}
app.t.Quit()
wv.Terminate()
}

6
app/package-lock.json generated Normal file
View File

@@ -0,0 +1,6 @@
{
"name": "app",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

View File

@@ -14,7 +14,7 @@ import (
// currentSchemaVersion defines the current database schema version.
// Increment this when making schema changes that require migrations.
const currentSchemaVersion = 12
const currentSchemaVersion = 13
// database wraps the SQLite connection.
// SQLite handles its own locking for concurrent access:
@@ -95,7 +95,8 @@ func (db *database) init() error {
id TEXT PRIMARY KEY,
title TEXT NOT NULL DEFAULT '',
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
browser_state TEXT
browser_state TEXT,
draft TEXT NOT NULL DEFAULT ''
);
CREATE TABLE IF NOT EXISTS messages (
@@ -244,6 +245,12 @@ func (db *database) migrate() error {
return fmt.Errorf("migrate v11 to v12: %w", err)
}
version = 12
case 12:
// add draft column to chats table
if err := db.migrateV12ToV13(); err != nil {
return fmt.Errorf("migrate v12 to v13: %w", err)
}
version = 13
default:
// If we have a version we don't recognize, just set it to current
// This might happen during development
@@ -452,6 +459,21 @@ func (db *database) migrateV11ToV12() error {
return nil
}
// migrateV12ToV13 adds the draft column to the chats table
func (db *database) migrateV12ToV13() error {
_, err := db.conn.Exec(`ALTER TABLE chats ADD COLUMN draft TEXT NOT NULL DEFAULT ''`)
if err != nil && !duplicateColumnError(err) {
return fmt.Errorf("add draft column: %w", err)
}
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
if err != nil {
return fmt.Errorf("update schema version: %w", err)
}
return nil
}
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
func (db *database) cleanupOrphanedData() error {
_, err := db.conn.Exec(`
@@ -570,7 +592,7 @@ func (db *database) getAllChats() ([]Chat, error) {
func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Chat, error) {
query := `
SELECT id, title, created_at, browser_state
SELECT id, title, created_at, browser_state, draft
FROM chats
WHERE id = ?
`
@@ -578,12 +600,14 @@ func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Cha
var chat Chat
var createdAt time.Time
var browserState sql.NullString
var draft sql.NullString
err := db.conn.QueryRow(query, id).Scan(
&chat.ID,
&chat.Title,
&createdAt,
&browserState,
&draft,
)
if err != nil {
if err == sql.ErrNoRows {
@@ -599,6 +623,9 @@ func (db *database) getChatWithOptions(id string, loadAttachmentData bool) (*Cha
chat.BrowserState = raw
}
}
if draft.Valid {
chat.Draft = draft.String
}
messages, err := db.getMessages(id, loadAttachmentData)
if err != nil {
@@ -622,11 +649,12 @@ func (db *database) saveChat(chat Chat) error {
// UPSERT would overwrite browser_state with NULL, breaking revisit rendering that relies
// on the last persisted full tool state.
query := `
INSERT INTO chats (id, title, created_at, browser_state)
VALUES (?, ?, ?, ?)
INSERT INTO chats (id, title, created_at, browser_state, draft)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
title = excluded.title,
browser_state = COALESCE(excluded.browser_state, chats.browser_state)
browser_state = COALESCE(excluded.browser_state, chats.browser_state),
draft = excluded.draft
`
var browserState sql.NullString
@@ -639,6 +667,7 @@ func (db *database) saveChat(chat Chat) error {
chat.Title,
chat.CreatedAt,
browserState,
chat.Draft,
)
if err != nil {
return fmt.Errorf("save chat: %w", err)
@@ -669,6 +698,23 @@ func (db *database) saveChat(chat Chat) error {
return tx.Commit()
}
// updateChatDraft updates only the draft for a chat
func (db *database) updateChatDraft(chatID string, draft string) error {
_, err := db.conn.Exec(`UPDATE chats SET draft = ? WHERE id = ?`, draft, chatID)
if err != nil {
return fmt.Errorf("update chat draft: %w", err)
}
return nil
}
func (db *database) clearAllDrafts() error {
_, err := db.conn.Exec(`UPDATE chats SET draft = ''`)
if err != nil {
return fmt.Errorf("clear all drafts: %w", err)
}
return nil
}
// updateChatBrowserState updates only the browser_state for a chat
func (db *database) updateChatBrowserState(chatID string, state json.RawMessage) error {
_, err := db.conn.Exec(`UPDATE chats SET browser_state = ? WHERE id = ?`, string(state), chatID)

View File

@@ -109,6 +109,7 @@ type Chat struct {
Title string `json:"title"`
CreatedAt time.Time `json:"created_at"`
BrowserState json.RawMessage `json:"browser_state,omitempty" ts_type:"BrowserStateData"`
Draft string `json:"draft,omitempty"`
}
// NewChat creates a new Chat with the ID, with CreatedAt timestamp initialized
@@ -451,6 +452,22 @@ func (s *Store) AppendMessage(chatID string, message Message) error {
return s.db.appendMessage(chatID, message)
}
func (s *Store) UpdateChatDraft(chatID string, draft string) error {
if err := s.ensureDB(); err != nil {
return err
}
return s.db.updateChatDraft(chatID, draft)
}
func (s *Store) ClearAllDrafts() error {
if err := s.ensureDB(); err != nil {
return err
}
return s.db.clearAllDrafts()
}
func (s *Store) UpdateChatBrowserState(chatID string, state json.RawMessage) error {
if err := s.ensureDB(); err != nil {
return err

View File

@@ -159,6 +159,7 @@ export class Chat {
title: string;
created_at: Time;
browser_state?: BrowserStateData;
draft?: string;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
@@ -167,6 +168,7 @@ export class Chat {
this.title = source["title"];
this.created_at = this.convertValues(source["created_at"], Time);
this.browser_state = source["browser_state"];
this.draft = source["draft"];
}
convertValues(a: any, classs: any, asMap: boolean = false): any {

View File

@@ -299,6 +299,20 @@ export async function renameChat(chatId: string, title: string): Promise<void> {
}
}
export async function updateChatDraft(chatId: string, draft: string): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}/draft`, {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ draft }),
});
if (!response.ok) {
const error = await response.text();
throw new Error(error || "Failed to update draft");
}
}
export async function deleteChat(chatId: string): Promise<void> {
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "DELETE",

View File

@@ -282,6 +282,7 @@ export default function Chat({ chatId }: { chatId: string }) {
onSubmit={handleChatFormSubmit}
chatId={chatId}
autoFocus={true}
initialDraft={chatQuery?.data?.chat?.draft ?? ""}
editingMessage={editingMessage}
onCancelEdit={handleCancelEdit}
isDisabled={isDisabled}

View File

@@ -27,6 +27,7 @@ import { ErrorMessage } from "./ErrorMessage";
import { processFiles } from "@/utils/fileValidation";
import type { ImageData } from "@/types/webview";
import { PlusIcon } from "@heroicons/react/24/outline";
import { useDraftMessage } from "@/hooks/useDraftMessage";
export type ThinkingLevel = "low" | "medium" | "high";
@@ -62,6 +63,7 @@ interface ChatFormProps {
chatId?: string;
isDownloadingModel?: boolean;
isDisabled?: boolean;
initialDraft?: string;
// Editing props - when provided, ChatForm enters edit mode
editingMessage?: {
content: string;
@@ -84,6 +86,7 @@ function ChatForm({
chatId = "new",
isDownloadingModel = false,
isDisabled = false,
initialDraft,
editingMessage,
onCancelEdit,
onFilesReceived,
@@ -118,6 +121,8 @@ function ChatForm({
null,
);
const { saveDraft, clearDraft } = useDraftMessage(chatId);
const handleThinkingLevelDropdownToggle = (isOpen: boolean) => {
if (
isOpen &&
@@ -308,10 +313,39 @@ function ChatForm({
}
}, [editingMessage]);
// Clear composition and reset textarea height when chatId changes
useEffect(() => {
resetChatForm();
}, [chatId]);
if (editingMessage) {
return;
}
if (initialDraft && initialDraft.trim()) {
setMessage({
content: initialDraft,
attachments: [],
fileErrors: [],
});
// Adjust textarea height after loading draft
setTimeout(() => {
if (textareaRef.current && initialDraft) {
textareaRef.current.style.height = "auto";
textareaRef.current.style.height =
Math.min(textareaRef.current.scrollHeight, 24 * 8) + "px";
}
}, 0);
} else {
resetChatForm();
}
}, [chatId, initialDraft, editingMessage]);
// Save draft only when navigating away or on blur
useEffect(() => {
return () => {
if (!editingMessage && message.content.trim()) {
saveDraft(message.content);
}
};
}, [message.content, editingMessage, saveDraft]);
// Auto-focus textarea when autoFocus is true or when streaming completes (but not when editing)
useEffect(() => {
@@ -511,12 +545,13 @@ function ChatForm({
});
}
// Clear composition after successful submission
// Clear composition and draft after successful submission
setMessage({
content: "",
attachments: [],
fileErrors: [],
});
clearDraft();
// Reset textarea height and refocus after submit
setTimeout(() => {
@@ -621,6 +656,13 @@ function ChatForm({
e.target.style.height = Math.min(e.target.scrollHeight, 24 * 8) + "px";
};
// Save draft when textarea loses focus
const handleTextareaBlur = () => {
if (!editingMessage && message.content.trim()) {
saveDraft(message.content);
}
};
const handleFilesUpload = async () => {
try {
setFileUploadError(null);
@@ -832,6 +874,7 @@ function ChatForm({
ref={textareaRef}
value={message.content}
onChange={handleTextareaChange}
onBlur={handleTextareaBlur}
placeholder="Send a message"
disabled={isDisabled}
className={`allow-context-menu w-full overflow-y-auto text-neutral-700 outline-none resize-none border-none bg-transparent dark:text-white placeholder:text-neutral-400 dark:placeholder:text-neutral-500 min-h-[24px] leading-6 transition-opacity duration-300 ${

View File

@@ -16,7 +16,6 @@ import {
ArrowLeftIcon,
} from "@heroicons/react/20/solid";
import { Settings as SettingsType } from "@/gotypes";
import { useNavigate } from "@tanstack/react-router";
import { useUser } from "@/hooks/useUser";
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { getSettings, updateSettings } from "@/api";
@@ -52,7 +51,6 @@ export default function Settings() {
const [isAwaitingConnection, setIsAwaitingConnection] = useState(false);
const [connectionError, setConnectionError] = useState<string | null>(null);
const [pollingInterval, setPollingInterval] = useState<number | null>(null);
const navigate = useNavigate();
const {
data: settingsData,
@@ -216,7 +214,7 @@ export default function Settings() {
>
{isWindows && (
<button
onClick={() => navigate({ to: "/" })}
onClick={() => window.history.back()}
className="hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full p-1.5"
>
<ArrowLeftIcon className="w-5 h-5 dark:text-white" />
@@ -226,7 +224,7 @@ export default function Settings() {
</h1>
{!isWindows && (
<button
onClick={() => navigate({ to: "/" })}
onClick={() => window.history.back()}
className="p-1 hover:bg-neutral-100 mr-3 dark:hover:bg-neutral-800 rounded-full"
>
<XMarkIcon className="w-6 h-6 dark:text-white" />

View File

@@ -0,0 +1,34 @@
import { useCallback } from "react";
import { updateChatDraft } from "@/api";
export function useDraftMessage(chatId: string) {
const saveDraft = useCallback(async (content: string) => {
try {
if (chatId === "new") {
return;
}
await updateChatDraft(chatId, content);
} catch (error) {
console.error("Error saving draft message:", error);
}
}, [chatId]);
const clearDraft = useCallback(async () => {
try {
if (chatId === "new") {
return;
}
await updateChatDraft(chatId, "");
} catch (error) {
console.error("Error clearing draft message:", error);
}
}, [chatId]);
return {
saveDraft,
clearDraft,
};
}

View File

@@ -12,13 +12,13 @@ import (
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"os"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -117,66 +117,40 @@ func (s *Server) log() *slog.Logger {
// ollamaProxy creates a reverse proxy handler to the Ollama server
func (s *Server) ollamaProxy() http.Handler {
var (
proxy http.Handler
proxyMu sync.Mutex
)
ollamaHost := os.Getenv("OLLAMA_HOST")
if ollamaHost == "" {
ollamaHost = "http://127.0.0.1:11434"
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyMu.Lock()
p := proxy
proxyMu.Unlock()
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
ollamaHost = "http://" + ollamaHost
}
if p == nil {
proxyMu.Lock()
if proxy == nil {
var err error
for i := range 2 {
if i > 0 {
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
time.Sleep(1 * time.Second)
}
target, err := url.Parse(ollamaHost)
if err != nil {
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
})
}
err = WaitForServer(context.Background(), 10*time.Second)
if err == nil {
break
}
}
s.log().Info("configuring ollama proxy", "target", target.String())
if err != nil {
proxyMu.Unlock()
s.log().Error("ollama server not ready after retries", "error", err)
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
return
}
proxy := httputil.NewSingleHostReverseProxy(target)
target := envconfig.Host()
s.log().Info("configuring ollama proxy", "target", target.String())
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
newProxy := httputil.NewSingleHostReverseProxy(target)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
originalDirector := newProxy.Director
newProxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
proxy = newProxy
p = newProxy
} else {
p = proxy
}
proxyMu.Unlock()
}
p.ServeHTTP(w, r)
})
return proxy
}
type errHandlerFunc func(http.ResponseWriter, *http.Request) error
@@ -279,6 +253,7 @@ func (s *Server) Handler() http.Handler {
mux.Handle("DELETE /api/v1/chat/{id}", handle(s.deleteChat))
mux.Handle("POST /api/v1/create-chat", handle(s.createChat))
mux.Handle("PUT /api/v1/chat/{id}/rename", handle(s.renameChat))
mux.Handle("PUT /api/v1/chat/{id}/draft", handle(s.updateDraft))
mux.Handle("GET /api/v1/inference-compute", handle(s.getInferenceCompute))
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
@@ -1302,6 +1277,28 @@ func (s *Server) renameChat(w http.ResponseWriter, r *http.Request) error {
return nil
}
func (s *Server) updateDraft(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {
return fmt.Errorf("chat ID is required")
}
var req struct {
Draft string `json:"draft"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return fmt.Errorf("invalid request body: %w", err)
}
if err := s.Store.UpdateChatDraft(cid, req.Draft); err != nil {
return fmt.Errorf("failed to update draft: %w", err)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
return nil
}
func (s *Server) deleteChat(w http.ResponseWriter, r *http.Request) error {
cid := r.PathValue("id")
if cid == "" {

View File

@@ -13,7 +13,6 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
"github.com/ollama/ollama/ml"
)
type GGML struct {
@@ -551,7 +550,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength()
@@ -792,7 +791,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention == ml.FlashAttentionEnabled {
if useFlashAttention {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
@@ -810,14 +809,6 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
return false
}
return true
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]

View File

@@ -118,7 +118,7 @@ type ContextParams struct {
c C.struct_llama_context_params
}
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize * numSeqMax)
@@ -127,13 +127,10 @@ func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, fla
params.n_threads = C.int(threads)
params.n_threads_batch = params.n_threads
params.embeddings = C.bool(true)
switch flashAttention {
case ml.FlashAttentionEnabled:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
case ml.FlashAttentionDisabled:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
case ml.FlashAttentionAuto:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
if flashAttention {
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
} else {
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
}
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

View File

@@ -188,11 +188,6 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if len(projectors) > 0 && llamaModel != nil {
loadRequest.ProjectorPath = projectors[0]
}
// Determine if the user has forced FA on or off
faUserSet := false
if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
faUserSet = true
}
fa := envconfig.FlashAttention(f.FlashAttention())
@@ -210,51 +205,19 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if textProcessor == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
flashAttention = ml.FlashAttentionEnabled
} else {
flashAttention = ml.FlashAttentionDisabled
}
}
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = true
if kvct != "" {
if f.KVCacheTypeIsQuantized(kvct) {
if flashAttention != ml.FlashAttentionEnabled {
slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
loadRequest.KvCacheType = ""
} else if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
} else {
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
}
}
loadRequest.FlashAttention = flashAttention
} else {
// For Ollama engine, use our SupportsFlashAttention logic
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = ml.FlashAttentionEnabled
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
}
gpuLibs := ml.LibraryPaths(gpus)
@@ -472,7 +435,7 @@ type LoadRequest struct {
LoraPath []string
Parallel int
BatchSize int
FlashAttention ml.FlashAttentionType
FlashAttention bool
KvSize int
KvCacheType string
NumThreads int

View File

@@ -74,7 +74,7 @@ type BackendParams struct {
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention FlashAttentionType
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))

View File

@@ -109,7 +109,7 @@ type Backend struct {
// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
flashAttention ml.FlashAttentionType
flashAttention bool
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
}
func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention == ml.FlashAttentionEnabled {
if b.flashAttention {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention == ml.FlashAttentionEnabled {
if t.b.flashAttention {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)

View File

@@ -492,32 +492,6 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
return true
}
type FlashAttentionType int32
const (
// Aligned with llama_flash_attn_type
FlashAttentionAuto FlashAttentionType = -1
FlashAttentionDisabled FlashAttentionType = 0
FlashAttentionEnabled FlashAttentionType = 1
)
func (f FlashAttentionType) LogValue() slog.Value {
return slog.AnyValue(f.String())
}
func (f FlashAttentionType) String() string {
switch f {
case FlashAttentionAuto:
return "Auto"
case FlashAttentionDisabled:
return "Disabled"
case FlashAttentionEnabled:
return "Enabled"
default:
return "unknown"
}
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices

View File

@@ -2,6 +2,7 @@ package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -12,26 +13,25 @@ import (
)
type TextConfig struct {
hiddenSize, contextLength, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase float32
largeModelScaling bool
slidingWindow uint32
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase float32
largeModelScaling bool
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
}
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOriginalContext),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
)
}
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@@ -55,9 +55,6 @@ type TextModel struct {
const (
gemmaGlobalCacheCount = 6
gemma1BLayerCount = 26
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
@@ -73,7 +70,6 @@ func newTextModel(c fs.Config) *TextModel {
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
contextLength: int(c.Uint("context_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
@@ -81,7 +77,6 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeBase: c.Float("rope.freq_base", 1000000.0),
slidingWindow: c.Uint("attention.sliding_window"),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
ropeType: c.String("rope.scaling.type"),
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
@@ -93,20 +88,14 @@ func newTextModel(c fs.Config) *TextModel {
},
}
// Apply corrections for older versions of the Gemma 3 models
// by looking at whether they use sliding window attention and
// based on their layer counts.
if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
switch numBlocks {
case gemma1BLayerCount:
// The 1B model has final logit softcapping set to 30.0
// but it should be 0.0
m.TextConfig.finalLogitSoftcap = 0.0
case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
// The 4B, 12B, and 27B models have rope scale unset
// but it shuold be set to 8.0
m.TextConfig.ropeScale = 8.0
}
// Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, and so force it to 0.0
// TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of
// models may include both sliding window attention and final
// logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
m.TextConfig.finalLogitSoftcap = 0.0
}
if numBlocks == gemma27BLayerCount {
@@ -125,31 +114,31 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
return opts.ropeLocalBase, 1.0
return opts.ropeLocalBase
}
// Standard Gemma3: only every n-th layer is global,
// where n = gemmaGlobalCacheCount, otherwise use
// the local rope base
if (layer+1)%gemmaGlobalCacheCount > 0 {
return opts.ropeLocalBase, 1.0
return opts.ropeLocalBase
}
// default to global rope base
return opts.ropeBase, opts.ropeScale
return opts.ropeBase
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
ropeBase := opts.ropeBaseForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -160,7 +149,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -173,8 +162,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
}
type TextMLP struct {

View File

@@ -1,292 +0,0 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type DeepSeekParserState int
const (
DeepSeekCollectingThinking DeepSeekParserState = iota
DeepSeekCollectingContent
DeepSeekCollectingToolCalls
DeepSeekCollectingToolOutput
)
const (
deepseekThinkingCloseTag = "</think>"
deepseekToolCallsBeginTag = "<tool▁calls▁begin>"
deepseekToolCallsEndTag = "<tool▁calls▁end>"
deepseekToolCallBeginTag = "<tool▁call▁begin>"
deepseekToolCallEndTag = "<tool▁call▁end>"
deepseekToolSepTag = "<tool▁sep>"
deepseekToolOutputBeginTag = "<tool▁output▁begin>"
deepseekToolOutputEndTag = "<tool▁output▁end>"
)
type DeepSeekParser struct {
state DeepSeekParserState
buffer strings.Builder
hasThinkingSupport bool
}
func (p *DeepSeekParser) HasToolSupport() bool {
return true
}
func (p *DeepSeekParser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *DeepSeekParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := p.HasThinkingSupport() && (thinkValue == nil || thinkValue.Bool())
if !thinkingEnabled {
p.state = DeepSeekCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = DeepSeekCollectingContent
return
}
p.state = DeepSeekCollectingThinking
}
func (p *DeepSeekParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, tools, thinkValue)
return tools
}
type deepseekEvent interface {
isDeepSeekEvent()
}
type deepseekEventThinkingContent struct {
content string
}
type deepseekEventContent struct {
content string
}
type deepseekEventToolCall struct {
toolCall api.ToolCall
}
func (deepseekEventThinkingContent) isDeepSeekEvent() {}
func (deepseekEventContent) isDeepSeekEvent() {}
func (deepseekEventToolCall) isDeepSeekEvent() {}
func (p *DeepSeekParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case deepseekEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case deepseekEventThinkingContent:
thinkingSb.WriteString(event.content)
case deepseekEventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *DeepSeekParser) parseEvents() []deepseekEvent {
var all []deepseekEvent
keepLooping := true
for keepLooping {
var events []deepseekEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *DeepSeekParser) eat() ([]deepseekEvent, bool) {
var events []deepseekEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case DeepSeekCollectingThinking:
if strings.Contains(bufStr, deepseekThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, deepseekThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = DeepSeekCollectingContent
if len(thinking) > 0 {
events = append(events, deepseekEventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, deepseekThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, deepseekEventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, deepseekEventThinkingContent{content: unambiguous})
}
return events, false
}
case DeepSeekCollectingContent:
switch {
case strings.Contains(bufStr, deepseekToolCallsBeginTag): // content[<tool▁calls▁begin>] -> tool calls
split := strings.SplitN(bufStr, deepseekToolCallsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = DeepSeekCollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, deepseekEventContent{content: contentBefore})
}
return events, true
case strings.Contains(bufStr, deepseekToolOutputBeginTag): // content[<tool▁output▁begin>] -> tool output
split := strings.SplitN(bufStr, deepseekToolOutputBeginTag, 2)
contentBefore := split[0] // Don't trim whitespace - preserve spaces
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = DeepSeekCollectingToolOutput
if len(contentBefore) > 0 {
events = append(events, deepseekEventContent{content: contentBefore})
}
return events, true
default: // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, deepseekEventContent{content: bufStr})
}
return events, false
}
case DeepSeekCollectingToolCalls:
if idx := strings.Index(bufStr, deepseekToolCallBeginTag); idx != -1 {
startIdx := idx + len(deepseekToolCallBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], deepseekToolCallEndTag); endIdx != -1 {
toolCallContent := bufStr[startIdx : startIdx+endIdx]
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
remaining := bufStr[startIdx+endIdx+len(deepseekToolCallEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
events = append(events, deepseekEventToolCall{toolCall: toolCall})
return events, true
} else {
slog.Warn("deepseek tool call parsing failed", "error", err)
}
}
}
if idx := strings.Index(bufStr, deepseekToolCallsEndTag); idx != -1 {
remaining := bufStr[idx+len(deepseekToolCallsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = DeepSeekCollectingContent
return events, true
}
return events, false
case DeepSeekCollectingToolOutput:
if idx := strings.Index(bufStr, deepseekToolOutputEndTag); idx != -1 {
toolOutputContent := bufStr[:idx]
remaining := bufStr[idx+len(deepseekToolOutputEndTag):]
// Don't trim whitespace - preserve spaces after tool output tags
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = DeepSeekCollectingContent
if len(toolOutputContent) > 0 {
events = append(events, deepseekEventContent{content: toolOutputContent})
}
return events, true
}
return events, false
}
return events, false
}
func (p *DeepSeekParser) parseToolCallContent(content string) (api.ToolCall, error) {
// Expected format: tool_name<tool▁sep>{args}
parts := strings.SplitN(content, deepseekToolSepTag, 2)
if len(parts) < 2 {
return api.ToolCall{}, errors.New("invalid format")
}
toolName := strings.TrimSpace(parts[0])
argsJSON := strings.TrimSpace(parts[1])
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
return api.ToolCall{}, err
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: toolName,
Arguments: args,
},
}, nil
}

View File

@@ -1,721 +0,0 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestDeepSeekParser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
hasThinking bool
}{
{
name: "simple_content",
input: "Hello, how are you?",
expectedContent: "Hello, how are you?",
hasThinking: false,
},
{
name: "thinking_content",
input: "I need to think about this...</think>The answer is 42.",
expectedThinking: "I need to think about this...",
expectedContent: "The answer is 42.",
hasThinking: true,
},
{
name: "no_thinking_simple",
input: "Just a regular response.",
expectedContent: "Just a regular response.",
hasThinking: false,
},
{
name: "thinking_with_newlines",
input: "Let me think:\n- Point 1\n- Point 2</think>\n\nHere's my answer.",
expectedThinking: "Let me think:\n- Point 1\n- Point 2",
expectedContent: "Here's my answer.",
hasThinking: true,
},
{
name: "tool_call_simple",
input: "I'll check the weather.<tool▁calls▁begin><tool▁call▁begin>get_weather<tool▁sep>{\"location\":\"Paris\"}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "I'll check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
hasThinking: false,
},
{
name: "multiple_tool_calls",
input: "Getting weather for both cities.<tool▁calls▁begin><tool▁call▁begin>get_weather<tool▁sep>{\"location\":\"Paris\"}<tool▁call▁end><tool▁call▁begin>get_weather<tool▁sep>{\"location\":\"London\"}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Getting weather for both cities.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
hasThinking: false,
},
{
name: "tool_output",
input: "Here's the weather: <tool▁output▁begin>Temperature: 22°C, Sunny<tool▁output▁end> Hope that helps!",
expectedContent: "Here's the weather: Temperature: 22°C, Sunny Hope that helps!",
hasThinking: false,
},
{
name: "complex_tool_arguments",
input: "Processing data.<tool▁calls▁begin><tool▁call▁begin>process_data<tool▁sep>{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []interface{}{"item1", "item2"},
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
},
},
},
},
hasThinking: false,
},
{
name: "thinking_with_tool_call", // technically this can't happen, but the parser can handle it
input: "Let me check the weather...</think>I'll get that for you.<tool▁calls▁begin><tool▁call▁begin>get_weather<tool▁sep>{\"location\":\"Paris\"}<tool▁call▁end><tool▁calls▁end>",
expectedThinking: "Let me check the weather...",
expectedContent: "I'll get that for you.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
hasThinking: true,
},
{
name: "empty_content",
input: "",
expectedContent: "",
hasThinking: false,
},
{
name: "only_thinking",
input: "Just thinking content</think>",
expectedThinking: "Just thinking content",
expectedContent: "",
hasThinking: true,
},
{
name: "multiple_tool_outputs",
input: "Results: <tool▁output▁begin>Paris: 22°C<tool▁output▁end> and <tool▁output▁begin>London: 18°C<tool▁output▁end>",
expectedContent: "Results: Paris: 22°C and London: 18°C",
hasThinking: false,
},
{
name: "unicode_content",
input: "مرحبا بالعالم! 你好世界! 🌍",
expectedContent: "مرحبا بالعالم! 你好世界! 🌍",
hasThinking: false,
},
{
name: "emoji_passthrough",
input: "Task completed ✅ 🎉",
expectedContent: "Task completed ✅ 🎉",
hasThinking: false,
},
{
name: "emoji_after_tool_call",
input: "I'll help you.<tool▁calls▁begin><tool▁call▁begin>get_weather<tool▁sep>{\"location\":\"Tokyo\"}<tool▁call▁end><tool▁calls▁end>完成 ✅",
expectedContent: "I'll help you.完成 ✅",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
},
},
},
hasThinking: false,
},
{
name: "newlines_and_whitespace",
input: "Line 1\n\nLine 3\t\tTabbed content",
expectedContent: "Line 1\n\nLine 3\t\tTabbed content",
hasThinking: false,
},
{
name: "thinking_with_unicode",
input: "我在思考这个问题...</think>答案是42。",
expectedThinking: "我在思考这个问题...",
expectedContent: "答案是42。",
hasThinking: true,
},
{
name: "tool_call_with_unicode_args",
input: "Searching for information.<tool▁calls▁begin><tool▁call▁begin>search<tool▁sep>{\"query\":\"北京天气\",\"language\":\"中文\"}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Searching for information.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: api.ToolCallFunctionArguments{
"query": "北京天气",
"language": "中文",
},
},
},
},
hasThinking: false,
},
{
name: "tool_output_with_unicode",
input: "天气信息: <tool▁output▁begin>北京: 25°C, 晴天<tool▁output▁end> 希望对您有帮助!",
expectedContent: "天气信息: 北京: 25°C, 晴天 希望对您有帮助!",
hasThinking: false,
},
{
name: "mixed_content_with_special_chars",
input: "Price: $100 & tax @ 10% = $110 <tool▁output▁begin>Total: $110<tool▁output▁end> (final)",
expectedContent: "Price: $100 & tax @ 10% = $110 Total: $110 (final)",
hasThinking: false,
},
{
name: "tool_call_with_special_chars",
input: "Processing data.<tool▁calls▁begin><tool▁call▁begin>execute_command<tool▁sep>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Processing data.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "execute_command",
Arguments: api.ToolCallFunctionArguments{
"command": "ls && echo \"done\"",
"path": "/home/user",
},
},
},
},
hasThinking: false,
},
{
name: "thinking_with_special_chars",
input: "Let me calculate: 2+2=4 & 3*3=9...</think>The results are correct!",
expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...",
expectedContent: "The results are correct!",
hasThinking: true,
},
{
name: "empty_tool_call_args",
input: "Pinging server.<tool▁calls▁begin><tool▁call▁begin>ping<tool▁sep>{}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Pinging server.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
hasThinking: false,
},
{
name: "empty_tool_output",
input: "Checking status: <tool▁output▁begin><tool▁output▁end> No output received.",
expectedContent: "Checking status: No output received.",
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
content, thinking, calls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestDeepSeekParser_Streaming(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
hasThinking bool
}{
{
name: "streaming_simple_content",
chunks: []string{"Hello, ", "how are ", "you?"},
expectedContent: "Hello, how are you?",
hasThinking: false,
},
{
name: "streaming_thinking",
chunks: []string{"I need to ", "think about this", "...</think>", "The answer is 42."},
expectedThinking: "I need to think about this...",
expectedContent: "The answer is 42.",
hasThinking: true,
},
{
name: "streaming_tool_call",
chunks: []string{"I'll check weather.", "<tool▁calls▁begin>", "<tool▁call▁begin>get_weather", "<tool▁sep>{\"location\":\"Paris\"}", "<tool▁call▁end><tool▁calls▁end>"},
expectedContent: "I'll check weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
hasThinking: false,
},
{
name: "streaming_thinking_with_partial_tag",
chunks: []string{"Thinking about this", "...</", "think>", "Done thinking."},
expectedThinking: "Thinking about this...",
expectedContent: "Done thinking.",
hasThinking: true,
},
{
name: "streaming_tool_output",
chunks: []string{"Weather info: ", "<tool▁output▁begin>", "25°C, Sunny", "<tool▁output▁end>", " Enjoy!"},
expectedContent: "Weather info: 25°C, Sunny Enjoy!",
hasThinking: false,
},
{
name: "streaming_with_split_tags",
chunks: []string{"Content before ", "<tool▁calls▁begin><tool▁call▁begin>test", "<tool▁sep>{}", "<tool▁call▁end><tool▁calls▁end>", " after"},
expectedContent: "Content before after",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
hasThinking: false,
},
{
name: "streaming_thinking_with_split_end_tag",
chunks: []string{"Thinking content", "</th", "ink>", "Regular content"},
expectedThinking: "Thinking content",
expectedContent: "Regular content",
hasThinking: true,
},
{
name: "streaming_unicode_content",
chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"},
expectedContent: "مرحبا بالعالم! 你好世界!",
hasThinking: false,
},
{
name: "streaming_multiple_tool_outputs",
chunks: []string{"Results: ", "<tool▁output▁begin>", "Paris: 22°C", "<tool▁output▁end>", " and ", "<tool▁output▁begin>", "London: 18°C", "<tool▁output▁end>"},
expectedContent: "Results: Paris: 22°C and London: 18°C",
hasThinking: false,
},
{
name: "streaming_tool_call_with_split_json",
chunks: []string{"Processing.", "<tool▁calls▁begin><tool▁call▁begin>calc<tool▁sep>{\"x\":", "42,\"y\":", "24}<tool▁call▁end><tool▁calls▁end>"},
expectedContent: "Processing.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calc",
Arguments: api.ToolCallFunctionArguments{
"x": float64(42),
"y": float64(24),
},
},
},
},
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
var allContent, allThinking string
var allCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, calls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
allContent += content
allThinking += thinking
allCalls = append(allCalls, calls...)
}
if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestDeepSeekParser_HasThinkingSupport(t *testing.T) {
tests := []struct {
name string
hasThinking bool
expectedSupport bool
}{
{
name: "thinking_enabled",
hasThinking: true,
expectedSupport: true,
},
{
name: "thinking_disabled",
hasThinking: false,
expectedSupport: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
if got := parser.HasThinkingSupport(); got != tt.expectedSupport {
t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport)
}
})
}
}
func TestDeepSeekParser_HasToolSupport(t *testing.T) {
parser := &DeepSeekParser{}
if !parser.HasToolSupport() {
t.Error("HasToolSupport() should return true")
}
}
func TestDeepSeekParser_Init(t *testing.T) {
parser := &DeepSeekParser{hasThinkingSupport: true}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "test_tool",
},
},
}
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
if diff := cmp.Diff(tools, returnedTools); diff != "" {
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
}
// Test initial state is set to thinking when enabled
if parser.state != DeepSeekCollectingThinking {
t.Errorf("Expected initial state to be DeepSeekCollectingThinking, got %v", parser.state)
}
}
func TestDeepSeekParser_parseToolCallContent(t *testing.T) {
tests := []struct {
name string
content string
expected api.ToolCall
expectError bool
}{
{
name: "valid_tool_call",
content: "get_weather<tool▁sep>{\"location\":\"Paris\"}",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
{
name: "complex_arguments",
content: "process_data<tool▁sep>{\"items\":[\"a\",\"b\"],\"config\":{\"enabled\":true}}",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true},
},
},
},
},
{
name: "empty_arguments",
content: "ping<tool▁sep>{}",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "ping",
Arguments: api.ToolCallFunctionArguments{},
},
},
},
{
name: "unicode_in_tool_name",
content: "获取天气<tool▁sep>{\"城市\":\"北京\"}",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: api.ToolCallFunctionArguments{
"城市": "北京",
},
},
},
},
{
name: "special_chars_in_arguments",
content: "execute<tool▁sep>{\"command\":\"ls && echo \\\"done\\\"\",\"path\":\"/home/user\"}",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "execute",
Arguments: api.ToolCallFunctionArguments{
"command": "ls && echo \"done\"",
"path": "/home/user",
},
},
},
},
{
name: "numeric_arguments",
content: "calculate<tool▁sep>{\"x\":3.14,\"y\":42,\"enabled\":true}",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: api.ToolCallFunctionArguments{
"x": 3.14,
"y": float64(42),
"enabled": true,
},
},
},
},
{
name: "invalid_format_no_separator",
content: "get_weather{\"location\":\"Paris\"}",
expectError: true,
},
{
name: "invalid_json",
content: "get_weather<tool▁sep>{invalid json}",
expectError: true,
},
{
name: "empty_tool_name",
content: "<tool▁sep>{\"arg\":\"value\"}",
expectError: false, // This should work, just empty name
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "",
Arguments: api.ToolCallFunctionArguments{
"arg": "value",
},
},
},
},
{
name: "missing_json_part",
content: "tool_name<tool▁sep>",
expectError: true,
},
}
parser := &DeepSeekParser{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parser.parseToolCallContent(tt.content)
if tt.expectError {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestDeepSeekParser_EdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
hasThinking bool
}{
{
name: "nested_think_tags_in_thinking",
input: "Outer thinking <think>inner</think> content</think>Final content",
expectedThinking: "Outer thinking <think>inner",
expectedContent: "content</think>Final content",
hasThinking: true,
},
{
name: "multiple_think_close_tags",
input: "First thought</think>Second thought</think>Final content",
expectedThinking: "First thought",
expectedContent: "Second thought</think>Final content",
hasThinking: true,
},
{
name: "empty_thinking_content",
input: "</think>Just content",
expectedThinking: "",
expectedContent: "Just content",
hasThinking: true,
},
{
name: "thinking_disabled_with_think_tags",
input: "Some content</think>More content",
expectedContent: "Some content</think>More content",
hasThinking: false,
},
{
name: "malformed_tool_call_missing_sep",
input: "Testing.<tool▁calls▁begin><tool▁call▁begin>bad_tool{\"arg\":\"value\"}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Testing.",
hasThinking: false,
},
{
name: "malformed_tool_call_invalid_json",
input: "Testing.<tool▁calls▁begin><tool▁call▁begin>bad_tool<tool▁sep>{invalid json}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "Testing.",
hasThinking: false,
},
{
name: "partial_tool_tag_at_end",
input: "Content with partial <tool▁calls▁",
expectedContent: "Content with partial <tool▁calls▁",
hasThinking: false,
},
{
name: "partial_think_tag_at_end",
input: "Thinking content</th",
expectedContent: "Thinking content</th",
hasThinking: false,
},
{
name: "partial_think_tag_at_end_with_thinking",
input: "Thinking content</th",
expectedThinking: "Thinking content",
expectedContent: "",
hasThinking: true,
},
{
name: "whitespace_only_content",
input: " \n\t ",
expectedContent: " \n\t ",
hasThinking: false,
},
{
name: "tool_output_with_newlines",
input: "Output:\n<tool▁output▁begin>Line 1\nLine 2\nLine 3<tool▁output▁end>\nDone.",
expectedContent: "Output:\nLine 1\nLine 2\nLine 3\nDone.",
hasThinking: false,
},
{
name: "consecutive_tool_calls",
input: "First.<tool▁calls▁begin><tool▁call▁begin>tool1<tool▁sep>{}<tool▁call▁end><tool▁calls▁end>Second.<tool▁calls▁begin><tool▁call▁begin>tool2<tool▁sep>{}<tool▁call▁end><tool▁calls▁end>",
expectedContent: "First.",
hasThinking: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &DeepSeekParser{hasThinkingSupport: tt.hasThinking}
parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
content, thinking, _, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("Content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -58,8 +58,6 @@ func ParserForName(name string) Parser {
return harmony.NewHarmonyMessageHandler()
case "cogito":
return &CogitoParser{}
case "deepseek":
return &DeepSeekParser{hasThinkingSupport: true}
case "olmo3":
return &Olmo3Parser{}
case "olmo3-think":

View File

@@ -10,15 +10,12 @@ import (
)
const (
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
olmo31DefaultSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai. "
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
)
type Olmo3Renderer struct {
UseExtendedSystemMessage bool
}
type Olmo3Renderer struct{}
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
@@ -54,11 +51,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
} else {
// Default system message - single newline after "system"
sb.WriteString("<|im_start|>system\n")
if r.UseExtendedSystemMessage {
sb.WriteString(olmo31DefaultSystemMessage)
} else {
sb.WriteString(olmo3DefaultSystemMessage)
}
sb.WriteString(olmo3DefaultSystemMessage)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
@@ -147,7 +140,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
}
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n")
sb.WriteString("<|im_start|>assistant\n\n")
}
return sb.String(), nil

View File

@@ -24,7 +24,7 @@ func TestOlmo3Renderer(t *testing.T) {
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "with system message no tools",
@@ -36,7 +36,7 @@ func TestOlmo3Renderer(t *testing.T) {
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "with system message and tools",
@@ -64,7 +64,7 @@ func TestOlmo3Renderer(t *testing.T) {
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "default system with tools - includes function instruction",
@@ -93,7 +93,7 @@ func TestOlmo3Renderer(t *testing.T) {
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "assistant with tool calls - function call syntax",
@@ -141,7 +141,7 @@ func TestOlmo3Renderer(t *testing.T) {
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "multi-turn conversation",
@@ -159,7 +159,7 @@ func TestOlmo3Renderer(t *testing.T) {
"Hi there!<|im_end|>\n" +
"<|im_start|>user\n" +
"How are you?<|im_end|>\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "parallel tool calls - newline separated",
@@ -214,7 +214,7 @@ func TestOlmo3Renderer(t *testing.T) {
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 55}<|im_end|>` + "\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "tool call with multiple arguments",
@@ -259,7 +259,7 @@ func TestOlmo3Renderer(t *testing.T) {
"Book a flight<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
"<|im_start|>assistant\n",
"<|im_start|>assistant\n\n",
},
{
name: "assistant prefill - no generation prompt",

View File

@@ -1,31 +1,31 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type Olmo3ThinkVariant int
const (
// Olmo3Think32B is for allenai/Olmo-3-32B-Think
Olmo3Think32B Olmo3ThinkVariant = iota
// Olmo31Think is for allenai/Olmo-3-7B-Think and allenai/Olmo-3.1-32B-Think (includes model info)
Olmo31Think
olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai."
olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions."
)
const (
olmo3ThinkFunctionsSuffix = " You do not currently have access to any functions. <functions></functions>"
olmo3Think32BSystemMessage = "You are a helpful AI assistant."
olmo31ThinkSystemMessage = "You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai."
)
type Olmo3ThinkRenderer struct{}
type Olmo3ThinkRenderer struct {
Variant Olmo3ThinkVariant
type olmo3ThinkToolCall struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function olmo3ThinkToolCallFunc `json:"function"`
}
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, _ []api.Tool, _ *api.ThinkValue) (string, error) {
type olmo3ThinkToolCallFunc struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
var systemMessage *api.Message
@@ -37,31 +37,34 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, _ []api.Tool, _ *api
}
continue
}
// Skip tool messages - Think models don't support tools
if message.Role == "tool" {
continue
}
filteredMessages = append(filteredMessages, message)
}
sb.WriteString("<|im_start|>system\n")
systemContent := olmo3ThinkDefaultSystemMessage
if systemMessage != nil {
sb.WriteString(systemMessage.Content)
sb.WriteString(olmo3ThinkFunctionsSuffix)
} else {
// Default system message varies by variant
switch r.Variant {
case Olmo3Think32B:
sb.WriteString(olmo3Think32BSystemMessage)
default: // Olmo3Think7B, Olmo31Think use same template - diverges from HF but confirmed difference from team
sb.WriteString(olmo31ThinkSystemMessage)
}
systemContent = systemMessage.Content
}
sb.WriteString("<|im_start|>system\n")
sb.WriteString(systemContent)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
if err != nil {
return "", err
}
sb.WriteString(" <functions>")
sb.WriteString(string(functionsJSON))
sb.WriteString("</functions>")
} else {
sb.WriteString(olmo3ThinkNoFunctionsMessage)
sb.WriteString(" <functions></functions>")
}
sb.WriteString("<|im_end|>\n")
for _, message := range filteredMessages {
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
switch message.Role {
case "user":
sb.WriteString("<|im_start|>user\n")
@@ -70,15 +73,58 @@ func (r *Olmo3ThinkRenderer) Render(messages []api.Message, _ []api.Tool, _ *api
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls))
for j, tc := range message.ToolCalls {
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
return "", err
}
toolCalls[j] = olmo3ThinkToolCall{
ID: tc.ID,
Type: "function",
Function: olmo3ThinkToolCallFunc{
Name: tc.Function.Name,
Arguments: string(argsJSON),
},
}
}
toolCallsJSON, err := marshalWithSpaces(toolCalls)
if err != nil {
return "", err
}
sb.WriteString("<function_calls>")
sb.WriteString(string(toolCallsJSON))
sb.WriteString("</function_calls>")
}
if !lastMessage {
sb.WriteString("<|im_end|>\n")
}
case "tool":
sb.WriteString("<|im_start|>environment\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
}
}
// Always add generation prompt with <think> tag for thinking models
sb.WriteString("<|im_start|>assistant\n<think>")
needsGenerationPrompt := true
if len(filteredMessages) > 0 {
lastMsg := filteredMessages[len(filteredMessages)-1]
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
needsGenerationPrompt = false
}
}
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n<think>")
}
return sb.String(), nil
}

View File

@@ -11,27 +11,24 @@ import (
func TestOlmo3ThinkRenderer(t *testing.T) {
tests := []struct {
name string
variant Olmo3ThinkVariant
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "7b_basic_without_system",
variant: Olmo31Think,
name: "basic without system - adds default system",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "7b_with_custom_system",
variant: Olmo31Think,
name: "with system message no tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
@@ -44,9 +41,9 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
"<think>",
},
{
name: "7b_tools_ignored",
variant: Olmo31Think,
name: "with system message and tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather?"},
},
tools: []api.Tool{
@@ -55,20 +52,27 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "7b_tool_calls_and_tool_messages_ignored",
variant: Olmo31Think,
name: "assistant with tool calls",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather in SF?"},
{
Role: "assistant",
@@ -77,33 +81,53 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
Name: "get_weather",
Arguments: map[string]any{
"location": "San Francisco",
},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`},
},
expected: "<|im_start|>system\n" +
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather in SF?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Let me check the weather.<|im_end|>\n" +
`Let me check the weather.<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "7b_multi_turn_conversation",
variant: Olmo31Think,
name: "multi-turn conversation",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
expected: "<|im_start|>system\n" +
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello<|im_end|>\n" +
"<|im_start|>assistant\n" +
@@ -114,56 +138,73 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
"<think>",
},
{
name: "32b_basic_without_system",
variant: Olmo3Think32B,
name: "parallel tool calls",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Get weather in SF and NYC"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful AI assistant.<|im_end|>\n" +
`You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. <functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"Get weather in SF and NYC<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 55}<|im_end|>` + "\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "32b_with_custom_system_gets_suffix",
variant: Olmo3Think32B,
name: "assistant message only content no tool calls",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
{Role: "user", Content: "Tell me a joke"},
{Role: "assistant", Content: "Why did the chicken cross the road?"},
{Role: "user", Content: "I don't know, why?"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"Tell me a joke<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "31_basic_without_system",
variant: Olmo31Think,
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are Olmo, a helpful AI assistant built by Ai2. Your date cutoff is December 2024, and your model weights are available at https://huggingface.co/allenai.<|im_end|>\n" +
"Why did the chicken cross the road?<|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "31_with_custom_system_gets_suffix",
variant: Olmo31Think,
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"I don't know, why?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
@@ -171,7 +212,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := (&Olmo3ThinkRenderer{Variant: tt.variant}).Render(tt.msgs, tt.tools, nil)
rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -60,18 +60,10 @@ func rendererForName(name string) Renderer {
renderer := &CogitoRenderer{isThinking: true}
return renderer
case "olmo3":
renderer := &Olmo3Renderer{UseExtendedSystemMessage: false}
return renderer
case "olmo3.1":
renderer := &Olmo3Renderer{UseExtendedSystemMessage: true}
renderer := &Olmo3Renderer{}
return renderer
case "olmo3-think":
// Used for Olmo-3-7B-Think and Olmo-3.1-32B-Think (same template)
renderer := &Olmo3ThinkRenderer{Variant: Olmo31Think}
return renderer
case "olmo3-32b-think":
// Used for Olmo-3-32B-Think
renderer := &Olmo3ThinkRenderer{Variant: Olmo3Think32B}
renderer := &Olmo3ThinkRenderer{}
return renderer
default:
return nil

View File

@@ -26,7 +26,6 @@ import (
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/runner/common"
)
@@ -833,7 +832,7 @@ func (s *Server) loadModel(
ppath string,
kvSize int,
kvCacheType string,
flashAttention ml.FlashAttentionType,
flashAttention bool,
threads int,
multiUserCache bool,
) {