Compare commits

..

25 Commits

Author SHA1 Message Date
Eva Ho
fc1f10cd0b fix test 2025-12-16 13:28:15 -05:00
Eva Ho
0e02a037b4 clean up 2025-12-15 17:49:24 -05:00
Eva Ho
4c2c51761e adding test 2025-12-15 17:46:26 -05:00
Eva Ho
c02bf766fd fix: windows app blank ui issue 2025-12-15 17:19:26 -05:00
Eva H
8dbc9e7b68 app/ui: handle unspecified bind addresses and wait for server in ollama proxy (#13159) 2025-12-15 13:33:09 -05:00
Daniel Hiltgen
abe67acf8a Revert "Enable Ollama engine by default" (#13481)
This reverts commit 56f754f46b.
2025-12-15 09:55:45 -08:00
Jeffrey Morgan
4ff8a691bc model: default gemma 3 rope scale to 1.0, apply corrections based on layer counts (#13453) 2025-12-12 17:51:56 -08:00
Jeffrey Morgan
1b308e1d2a model: fix global layer rope scale values for gemma 3 (#13452) 2025-12-12 16:29:01 -08:00
Daniel Hiltgen
bd6c1d6b49 flash attn: add auto mode for llama engine (#13052)
* flash attn: add auto mode for llama engine

If the user does not specify fa in the environment, use auto-mode.

* review comments

* ensure kv cache quantized types have FA explicitly enabled

additional review comments
2025-12-12 13:27:19 -08:00
Jeffrey Morgan
3af5d3b738 model: force rope factor 1.0 for Gemma 3 (#13445) 2025-12-12 13:27:08 -08:00
Daniel Hiltgen
7730895158 Enable Ollama engine by default (#13443)
This changes the default behavior to use the Ollama engine for supported
models, while retaining the ability to disable the Ollama engine and
fall back to the Llama engine.  Models in the OllamaEngineRequired list
will always run on the Ollama engine.
2025-12-12 11:48:43 -08:00
Eva H
de9ecfd01c tidy up lint warnings on windows (#13430) 2025-12-12 11:43:35 -05:00
Eva H
95fdd8d619 fix: select and update models folder in settings (#13412) 2025-12-12 11:09:37 -05:00
Devon Rifkin
9f7822851c docs: add docs for v1/responses and rework openai compat section (#13416)
* docs: add docs for v1/responses and rework openai compat section

I reworked the examples to be separated by topic and to be fully
runnable (i.e., they now log output instead of just suggesting how a
call might be made).

We now use `<CodeGroup>`s so that each example has a dropdown on the
docs site for users to choose, which makes the examples a lot more
digestible (since you only see approx 1/3 of the code you used to).

I also added a new tool to extract code examples into files so that it's
easier to actually run them and check that they work.

## Example

```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```

Output:

```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

  - 01_basic.py
  - 01_basic.js
  - 01_basic.sh
  - 02_responses.py
  - 02_responses.js
  - 02_responses.sh
  - 03_vision.py
  - 03_vision.js
  - 03_vision.sh

Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368

To run examples:

  cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
  npm install   # for JS examples

then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```

In the future we should consider actually running the examples in CI and
having some sort of acceptance test so we can automatically detect when
our examples break. So this is just a start in that direction.

* Update docs/api/openai-compatibility.mdx

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

* Update docs/api/openai-compatibility.mdx

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>

---------

Co-authored-by: Parth Sareen <parth.sareen@ollama.com>
2025-12-11 17:39:40 -08:00
Parth Sareen
9b2035d194 openai: add tool call appending to previous assistant message (#13434)
* openai: add tool call appending to previous asst message

* add tests for thinking appending
2025-12-11 17:30:12 -08:00
Alexander Gusak
93d45d7a04 docs: fix link to modelfile.mdx (#13220) 2025-12-11 16:14:45 -08:00
JJ
709f842457 Update README.md (#13373)
Correct Markdown syntax for Swollama GitHub and DocC documentation links
2025-12-11 16:08:57 -08:00
Jeffrey Morgan
2dfb74410d model: fix rotary embeddings for ministral 3 (#13432) 2025-12-11 16:02:05 -08:00
Devon Rifkin
1eb5e75972 openai: add v1/responses support (#13351)
Only supporting the stateless part of the API.

Doc updates to come once this is shipped.

Closes: #9659
2025-12-11 15:37:10 -08:00
nicole pardal
3475d915cb embeddings: modified batch size (#13429)
This PR detects embedding models and sets batch_size = context_size so the full input fits in a single batch.
Previously, if batch size was smaller than the input, tokens could be split across batches and cause a SIGTRAP crash.
This change ensures all tokens stay in one batch and prevents crashes.
Fixes: #12938 #13054

Co-authored-by: Jesse Gross <jesse@ollama.com>
2025-12-11 15:36:31 -08:00
Jeffrey Morgan
48e78e9be1 template: add yesterdayDate helper function (#13431) 2025-12-11 14:47:55 -08:00
Jeffrey Morgan
a838421ea3 model: conversion and hyperparameter fixes for ministral and devstral (#13424) 2025-12-11 13:04:00 -08:00
EasonLin
1c4e85b4df routes: add logprobs in tool calls (#13238) 2025-12-10 17:28:41 -08:00
Eloi Torrents
dac4f17fea cmd/bench: fix binary name in README (#13276) 2025-12-10 14:16:58 -08:00
Julia Scheaffer
56b8fb024c cmd/bench: fix options table in cmd/bench/README.md (#13216) 2025-12-10 14:07:48 -08:00
38 changed files with 4120 additions and 337 deletions

View File

@@ -555,7 +555,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama)
- [Swollama for Swift](https://github.com/guitaripod/Swollama) with [DocC](https://guitaripod.github.io/Swollama/documentation/swollama)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)

View File

@@ -347,7 +347,7 @@ type CreateProgressFunc func(ProgressResponse) error
// Create creates a model from a [Modelfile]. fn is a progress function that
// behaves similarly to other methods (see [Client.Pull]).
//
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.md
// [Modelfile]: https://github.com/ollama/ollama/blob/main/docs/modelfile.mdx
func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/create", req, func(bts []byte) error {
var resp ProgressResponse

View File

@@ -191,13 +191,6 @@ func LaunchNewApp() {
C.launchApp(appName)
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
C.uiRequest(p)
}
func registerLaunchAgent(hasCompletedFirstRun bool) {
// Remove any stale Login Item registrations
C.unregisterSelfFromLoginItem()

View File

@@ -263,11 +263,6 @@ func createLoginShortcut() error {
return nil
}
// Send a request to the main app thread to load a UI page
func sendUIRequestMessage(path string) {
wintray.SendUIRequestMessage(path)
}
func LaunchNewApp() {
}

View File

@@ -169,37 +169,47 @@ DlgResult fileDlg(FileDlgParams* params) {
}
NSArray* urls = [panel URLs];
if(self->params->allowMultiple && [urls count] >= 1) {
if([urls count] == 0) {
return DLG_CANCEL;
}
if(self->params->allowMultiple) {
// For multiple files, we need to return all paths separated by null bytes
char* bufPtr = self->params->buf;
int remainingBuf = self->params->nbuf;
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
// Calculate total required buffer size first
int totalSize = 0;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
if(![url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX]) {
return DLG_URLFAIL;
}
totalSize += strlen(tempBuf) + 1; // +1 for null terminator
}
totalSize += 1; // Final null terminator
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
if(totalSize > self->params->nbuf) {
// Not enough buffer space
return DLG_URLFAIL;
}
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
// Now actually copy the paths (we know we have space)
bufPtr = self->params->buf;
for(NSURL* url in urls) {
char tempBuf[PATH_MAX];
[url getFileSystemRepresentation:tempBuf maxLength:PATH_MAX];
int pathLen = strlen(tempBuf);
strcpy(bufPtr, tempBuf);
bufPtr += pathLen + 1;
}
*bufPtr = '\0'; // Final null terminator
} else {
// Single file/directory selection - write path to buffer
NSURL* url = [urls firstObject];
if(![url getFileSystemRepresentation:self->params->buf maxLength:self->params->nbuf]) {
return DLG_URLFAIL;
}
}
return DLG_OK;

View File

@@ -15,7 +15,7 @@ const multiFileBufferSize = w32.MAX_PATH * 10
type WinDlgError int
func (e WinDlgError) Error() string {
return fmt.Sprintf("CommDlgExtendedError: %#x", e)
return fmt.Sprintf("CommDlgExtendedError: %#x", int(e))
}
func err() error {

View File

@@ -224,9 +224,7 @@ func (s *Server) cmd(ctx context.Context) (*exec.Cmd, error) {
if _, err := os.Stat(settings.Models); err == nil {
env["OLLAMA_MODELS"] = settings.Models
} else {
slog.Warn("models path not accessible, clearing models setting", "path", settings.Models, "err", err)
settings.Models = ""
s.store.SetSettings(settings)
slog.Warn("models path not accessible, using default", "path", settings.Models, "err", err)
}
}
if settings.ContextLength > 0 {

View File

@@ -7,7 +7,9 @@ import (
"embed"
"errors"
"io/fs"
"mime"
"net/http"
"path/filepath"
"strings"
"time"
)
@@ -15,6 +17,13 @@ import (
//go:embed app/dist
var appFS embed.FS
func init() {
mime.AddExtensionType(".js", "application/javascript; charset=utf-8")
mime.AddExtensionType(".css", "text/css; charset=utf-8")
mime.AddExtensionType(".woff2", "font/woff2")
mime.AddExtensionType(".svg", "image/svg+xml")
}
// appHandler returns an HTTP handler that serves the React SPA.
// It tries to serve real files first, then falls back to index.html for React Router.
func (s *Server) appHandler() http.Handler {
@@ -24,11 +33,19 @@ func (s *Server) appHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p := strings.TrimPrefix(r.URL.Path, "/")
if _, err := fsys.Open(p); err == nil {
// Serve the file directly
if file, err := fsys.Open(p); err == nil {
file.Close()
// Ensure proper Content-Type headers
if contentType := mime.TypeByExtension(filepath.Ext(p)); contentType != "" {
w.Header().Set("Content-Type", contentType)
}
fileServer.ServeHTTP(w, r)
return
}
// Fallback serve index.html for unknown paths so React Router works
data, err := fs.ReadFile(fsys, "index.html")
if err != nil {
@@ -39,6 +56,8 @@ func (s *Server) appHandler() http.Handler {
}
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
http.ServeContent(w, r, "index.html", time.Time{}, bytes.NewReader(data))
})
}

39
app/ui/app_test.go Normal file
View File

@@ -0,0 +1,39 @@
//go:build windows || darwin
package ui
import (
"io/fs"
"strings"
"testing"
)
// TestEmbeddedAssets verifies that the correct UI assets are embedded.
// This test will FAIL THE BUILD if wrong files are embedded.
func TestEmbeddedAssets(t *testing.T) {
fsys, err := fs.Sub(appFS, "app/dist")
if err != nil {
t.Fatal("app/dist not found in embedded filesystem - UI not built")
}
data, err := fs.ReadFile(fsys, "index.html")
if err != nil {
t.Fatal("index.html not found - run 'go generate' first")
}
html := string(data)
if strings.Contains(html, "/src/main.tsx") {
t.Fatal("Wrong index.html embedded: has /src/main.tsx (dev paths). The UI was not built. Run 'npm run build' first.")
}
if !strings.Contains(html, "/assets/index-") {
t.Fatal("Wrong index.html embedded: missing /assets/index-* (production paths). The UI was not built correctly.")
}
if _, err := fsys.Open("assets"); err != nil {
t.Fatal("assets/ directory not found - UI build incomplete")
}
t.Log("Embedded assets verified - UI built correctly")
}

View File

@@ -12,13 +12,13 @@ import (
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"os"
"runtime"
"runtime/debug"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -117,40 +117,66 @@ func (s *Server) log() *slog.Logger {
// ollamaProxy creates a reverse proxy handler to the Ollama server
func (s *Server) ollamaProxy() http.Handler {
ollamaHost := os.Getenv("OLLAMA_HOST")
if ollamaHost == "" {
ollamaHost = "http://127.0.0.1:11434"
}
var (
proxy http.Handler
proxyMu sync.Mutex
)
if !strings.HasPrefix(ollamaHost, "http://") && !strings.HasPrefix(ollamaHost, "https://") {
ollamaHost = "http://" + ollamaHost
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyMu.Lock()
p := proxy
proxyMu.Unlock()
target, err := url.Parse(ollamaHost)
if err != nil {
s.log().Error("failed to parse OLLAMA_HOST", "error", err, "host", ollamaHost)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "failed to configure proxy", http.StatusInternalServerError)
})
}
if p == nil {
proxyMu.Lock()
if proxy == nil {
var err error
for i := range 2 {
if i > 0 {
s.log().Warn("ollama server not ready, retrying", "attempt", i+1)
time.Sleep(1 * time.Second)
}
s.log().Info("configuring ollama proxy", "target", target.String())
err = WaitForServer(context.Background(), 10*time.Second)
if err == nil {
break
}
}
proxy := httputil.NewSingleHostReverseProxy(target)
if err != nil {
proxyMu.Unlock()
s.log().Error("ollama server not ready after retries", "error", err)
http.Error(w, "Ollama server is not ready", http.StatusServiceUnavailable)
return
}
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
target := envconfig.Host()
s.log().Info("configuring ollama proxy", "target", target.String())
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
newProxy := httputil.NewSingleHostReverseProxy(target)
return proxy
originalDirector := newProxy.Director
newProxy.Director = func(req *http.Request) {
originalDirector(req)
req.Host = target.Host
s.log().Debug("proxying request", "method", req.Method, "path", req.URL.Path, "target", target.Host)
}
newProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
s.log().Error("proxy error", "error", err, "path", r.URL.Path, "target", target.String())
http.Error(w, "proxy error: "+err.Error(), http.StatusBadGateway)
}
proxy = newProxy
p = newProxy
} else {
p = proxy
}
proxyMu.Unlock()
}
p.ServeHTTP(w, r)
})
}
type errHandlerFunc func(http.ResponseWriter, *http.Request) error

View File

@@ -158,16 +158,16 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
case uint32(UI_REQUEST_MSG_ID):
// Requests for the UI must always come from the main event thread
l := int(wParam)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l) //nolint:govet,gosec
t.app.UIRun(path)
case WM_COPYDATA:
// Handle URL scheme requests from other instances
if lParam != 0 {
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
if cds.DwData == 1 { // Our identifier for URL scheme messages
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam)) //nolint:govet,gosec
if cds.DwData == 1 { // Our identifier for URL scheme messages
// Convert the data back to string
data := make([]byte, cds.CbData)
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData]) //nolint:govet,gosec
urlScheme := string(data)
handleURLSchemeRequest(urlScheme)
lResult = 1 // Return non-zero to indicate success

View File

@@ -15,7 +15,7 @@ A Go-based command-line tool for benchmarking Ollama models with configurable pa
```
go build -o ollama-bench bench.go
./bench -model gpt-oss:20b -epochs 6 -format csv
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
```
Using Go Run (without building)
@@ -29,31 +29,32 @@ go run bench.go -model gpt-oss:20b -epochs 3
### Basic Example
```
./bench -model gemma3 -epochs 6
./ollama-bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |

View File

@@ -182,6 +182,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{}
case "Ministral3ForCausalLM":
conv = &mistral3CausalModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":

View File

@@ -30,13 +30,15 @@ type mistral3Model struct {
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
@@ -50,6 +52,9 @@ type mistral3Model struct {
HeadDim uint32 `json:"head_dim"`
HiddenAct string `json:"hidden_act"`
RopeTheta float32 `json:"rope_theta"`
RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"vision_config"`
MultiModalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
@@ -72,10 +77,22 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.TextModel.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.TextModel.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.TextModel.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.TextModel.RopeParameters.BetaSlow
if p.TextModel.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.TextModel.RopeParameters.Mscale
}
if p.TextModel.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.TextModel.RopeParameters.MscaleAllDim
}
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
if p.TextModel.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.TextModel.RopeParameters.Llama4ScalingBeta
}
// Vision configuration
@@ -88,7 +105,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["mistral3.vision.rope.freq_base"] = cmp.Or(p.VisionModel.RopeTheta, p.VisionModel.RopeParameters.RopeTheta)
// Multimodal configuration
kv["mistral3.image_token_index"] = p.ImageTokenIndex

View File

@@ -0,0 +1,181 @@
package convert
import (
"cmp"
"fmt"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistral3CausalModel struct {
ModelParameters
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
Llama4ScalingBeta *float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
Mscale *float32 `json:"mscale"`
MscaleAllDim *float32 `json:"mscale_all_dim"`
} `json:"rope_parameters"`
}
func (p *mistral3CausalModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral3"
kv["mistral3.vocab_size"] = p.VocabSize
// Text configuration
kv["mistral3.block_count"] = p.NumHiddenLayers
kv["mistral3.context_length"] = p.MaxPositionEmbeddings
kv["mistral3.embedding_length"] = p.HiddenSize
kv["mistral3.feed_forward_length"] = p.IntermediateSize
kv["mistral3.attention.head_count"] = p.NumAttentionHeads
kv["mistral3.attention.head_count_kv"] = p.NumKeyValueHeads
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["mistral3.attention.key_length"] = p.HeadDim
kv["mistral3.attention.value_length"] = p.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.HeadDim, p.HiddenSize/p.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.RopeTheta, p.RopeParameters.RopeTheta)
kv["mistral3.rope.scaling.factor"] = p.RopeParameters.Factor
kv["mistral3.rope.scaling.type"] = p.RopeParameters.RopeType
kv["mistral3.rope.scaling.beta_fast"] = p.RopeParameters.BetaFast
kv["mistral3.rope.scaling.beta_slow"] = p.RopeParameters.BetaSlow
if p.RopeParameters.Mscale != nil {
kv["mistral3.rope.scaling.mscale"] = *p.RopeParameters.Mscale
}
if p.RopeParameters.MscaleAllDim != nil {
kv["mistral3.rope.scaling.mscale_all_dim"] = *p.RopeParameters.MscaleAllDim
}
if p.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
if p.RopeParameters.Llama4ScalingBeta != nil {
kv["mistral3.rope.scaling_beta"] = *p.RopeParameters.Llama4ScalingBeta
}
return kv
}
func (p *mistral3CausalModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") {
if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
strings.HasSuffix(t.Name(), ".attn_k.weight") {
t.SetRepacker(p.repack)
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistral3CausalModel) Replacements() []string {
return []string{
"model.norm", "output_norm",
"model.", "",
"layers", "blk",
"transformer.layers", "blk",
"vision_tower", "v",
"ln_pre", "encoder_norm",
"input_layernorm", "attn_norm",
"post_attention_layernorm", "ffn_norm",
"embed_tokens", "token_embd",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"mlp.down_proj", "ffn_down",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"ffn_norm", "ffn_norm",
"lm_head", "output",
}
}
func (p *mistral3CausalModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, ".attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, ".attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -50,7 +50,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
@@ -507,7 +507,7 @@ The `message` object has the following fields:
Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
@@ -1189,7 +1189,7 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
- `template`: (optional) the prompt template for the model
- `license`: (optional) a string or list of strings containing the license or licenses for the model
- `system`: (optional) a string containing the system prompt for the model
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.md#valid-parameters-and-values) for a list of parameters)
- `parameters`: (optional) a dictionary of parameters for the model (see [Modelfile](./modelfile.mdx#valid-parameters-and-values) for a list of parameters)
- `messages`: (optional) a list of message objects used to create a conversation
- `stream`: (optional) if `false` the response will be returned as a single response object, rather than a stream of objects
- `quantize` (optional): quantize a non-quantized (e.g. float16) model
@@ -1698,7 +1698,7 @@ Generate embeddings from a model
Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding
@@ -1817,7 +1817,7 @@ Generate embeddings from a model
Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,46 @@
# extract-examples
Extracts code examples from MDX files to a temp directory so you can run them.
## Usage
```shell
go run docs/tools/extract-examples/main.go <mdx-file>
```
## Example
```shell
go run docs/tools/extract-examples/main.go docs/api/openai-compatibility.mdx
```
Output:
```
Extracting code examples to: /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
- 01_basic.py
- 01_basic.js
- 01_basic.sh
- 02_responses.py
- 02_responses.js
- 02_responses.sh
- 03_vision.py
- 03_vision.js
- 03_vision.sh
Extracted 9 file(s) to /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
To run examples:
cd /var/folders/vq/wfm2g6k917d3ldzpjdxc8ph00000gn/T/mdx-examples-3271754368
npm install # for JS examples
then run individual files with `node file.js`, `python file.py`, `bash file.sh`
```
## How it works
- Parses MDX files looking for fenced code blocks with filenames (e.g., ` ```python basic.py `)
- Groups examples by their `<CodeGroup>` and prefixes filenames with `01_`, `02_`, etc.
- Writes all extracted files to a temp directory

View File

@@ -0,0 +1,137 @@
package main
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
)
func main() {
if len(os.Args) < 2 {
fmt.Fprintln(os.Stderr, "Usage: go run extract-examples.go <mdx-file>")
os.Exit(1)
}
mdxFile := os.Args[1]
f, err := os.Open(mdxFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
defer f.Close()
// Create temp directory
tempDir, err := os.MkdirTemp("", "mdx-examples-*")
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating temp dir: %v\n", err)
os.Exit(1)
}
fmt.Printf("Extracting code examples to: %s\n\n", tempDir)
// Patterns
codeBlockStart := regexp.MustCompile("^```([a-zA-Z0-9_-]+)\\s+([^\\s]+)$")
codeGroupStart := regexp.MustCompile("^<CodeGroup")
codeGroupEnd := regexp.MustCompile("^</CodeGroup>")
scanner := bufio.NewScanner(f)
inCodeBlock := false
inCodeGroup := false
var currentFile string
var content strings.Builder
count := 0
codeGroupNum := 0
for scanner.Scan() {
line := scanner.Text()
// Track CodeGroup boundaries
if codeGroupStart.MatchString(line) {
inCodeGroup = true
codeGroupNum++
continue
}
if codeGroupEnd.MatchString(line) {
inCodeGroup = false
continue
}
if inCodeBlock {
if line == "```" {
// End of code block - write file
if currentFile != "" {
outPath := filepath.Join(tempDir, currentFile)
if err := os.WriteFile(outPath, []byte(content.String()), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing %s: %v\n", currentFile, err)
} else {
fmt.Printf(" - %s\n", currentFile)
count++
}
}
inCodeBlock = false
currentFile = ""
content.Reset()
} else {
content.WriteString(line)
content.WriteString("\n")
}
} else {
if matches := codeBlockStart.FindStringSubmatch(line); matches != nil {
inCodeBlock = true
filename := matches[2]
// Prefix with CodeGroup number if inside a CodeGroup
if inCodeGroup {
currentFile = fmt.Sprintf("%02d_%s", codeGroupNum, filename)
} else {
currentFile = filename
}
content.Reset()
}
}
}
if err := scanner.Err(); err != nil {
fmt.Fprintf(os.Stderr, "Error reading file: %v\n", err)
os.Exit(1)
}
// Write package.json for JavaScript dependencies
packageJSON := `{
"name": "mdx-examples",
"type": "module",
"dependencies": {
"openai": "^4",
"ollama": "^0.5"
}
}
`
if err := os.WriteFile(filepath.Join(tempDir, "package.json"), []byte(packageJSON), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing package.json: %v\n", err)
}
// Write pyproject.toml for Python dependencies
pyprojectTOML := `[project]
name = "mdx-examples"
version = "0.0.0"
dependencies = [
"openai",
"ollama",
]
`
if err := os.WriteFile(filepath.Join(tempDir, "pyproject.toml"), []byte(pyprojectTOML), 0o644); err != nil {
fmt.Fprintf(os.Stderr, "Error writing pyproject.toml: %v\n", err)
}
fmt.Printf("\n")
fmt.Printf("Extracted %d file(s) to %s\n", count, tempDir)
fmt.Printf("\n")
fmt.Printf("To run examples:\n")
fmt.Printf("\n")
fmt.Printf(" cd %s\n npm install # for JS examples\n", tempDir)
fmt.Printf("\n")
fmt.Printf("then run individual files with `node file.js`, `python file.py`, `bash file.sh`\n")
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
"github.com/ollama/ollama/ml"
)
type GGML struct {
@@ -550,7 +551,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention ml.FlashAttentionType) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength()
@@ -791,7 +792,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention {
if useFlashAttention == ml.FlashAttentionEnabled {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
@@ -809,6 +810,14 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
// KVCacheTypeIsQuantized checks if the requested cache type is a quantized type
func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool {
if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" {
return false
}
return true
}
// SupportsFlashAttention checks if the model supports flash attention
func (f GGML) SupportsFlashAttention() bool {
_, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]

View File

@@ -487,6 +487,63 @@ func TestEmbedTruncation(t *testing.T) {
}
}
// TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes.
func TestEmbedLargeInput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute)
defer mcancel()
// Test with progressively larger inputs
testCases := []struct {
name string
inputWords int
}{
{"medium_input_256_words", 256},
{"large_input_512_words", 512},
{"very_large_input_800_words", 800},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
words := make([]string, tc.inputWords)
for i := range words {
words[i] = "word"
}
input := strings.Join(words, " ")
req := api.EmbedRequest{
Model: model,
Input: input,
KeepAlive: &api.Duration{Duration: 30 * time.Second},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatalf("embedding failed for %d words: %v", tc.inputWords, err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) == 0 {
t.Fatal("expected non-empty embedding")
}
t.Logf("Successfully embedded %d words (%d tokens)", tc.inputWords, res.PromptEvalCount)
})
}
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler

View File

@@ -118,18 +118,22 @@ type ContextParams struct {
c C.struct_llama_context_params
}
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention ml.FlashAttentionType, kvCacheType string) ContextParams {
params := C.llama_context_default_params()
params.n_ctx = C.uint(numCtx)
params.n_batch = C.uint(batchSize)
params.n_batch = C.uint(batchSize * numSeqMax)
params.n_ubatch = C.uint(batchSize)
params.n_seq_max = C.uint(numSeqMax)
params.n_threads = C.int(threads)
params.n_threads_batch = params.n_threads
params.embeddings = C.bool(true)
if flashAttention {
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_ENABLED
} else {
params.flash_attn_type = C.LLAMA_FLASH_ATTN_TYPE_DISABLED
switch flashAttention {
case ml.FlashAttentionEnabled:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_ENABLED)
case ml.FlashAttentionDisabled:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_DISABLED)
case ml.FlashAttentionAuto:
params.flash_attn_type = int32(C.LLAMA_FLASH_ATTN_TYPE_AUTO)
}
params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))

View File

@@ -188,6 +188,11 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
if len(projectors) > 0 && llamaModel != nil {
loadRequest.ProjectorPath = projectors[0]
}
// Determine if the user has forced FA on or off
faUserSet := false
if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) {
faUserSet = true
}
fa := envconfig.FlashAttention(f.FlashAttention())
@@ -205,19 +210,51 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = true
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
if textProcessor == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
flashAttention = ml.FlashAttentionEnabled
} else {
flashAttention = ml.FlashAttentionDisabled
}
}
if kvct != "" {
if f.KVCacheTypeIsQuantized(kvct) {
if flashAttention != ml.FlashAttentionEnabled {
slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct)
loadRequest.KvCacheType = ""
} else if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
} else {
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct)
}
}
}
loadRequest.FlashAttention = flashAttention
} else {
// For Ollama engine, use our SupportsFlashAttention logic
if fa {
slog.Info("enabling flash attention")
loadRequest.FlashAttention = ml.FlashAttentionEnabled
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
}
} else if kvct != "" && kvct != "f16" {
slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
}
gpuLibs := ml.LibraryPaths(gpus)
@@ -435,7 +472,7 @@ type LoadRequest struct {
LoraPath []string
Parallel int
BatchSize int
FlashAttention bool
FlashAttention ml.FlashAttentionType
KvSize int
KvCacheType string
NumThreads int
@@ -474,6 +511,13 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers)
}
// Check if embedding model and adjust batch size accordingly
_, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())]
if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx {
s.loadRequest.BatchSize = s.options.NumCtx
slog.Info("embedding model detected, setting batch size to context length", "batch_size", s.loadRequest.BatchSize)
}
kv, graphPartialOffload, graphFullOffload := s.ggml.GraphSize(uint64(s.options.NumCtx), uint64(s.loadRequest.BatchSize),
s.loadRequest.Parallel, s.loadRequest.KvCacheType, s.loadRequest.FlashAttention)

View File

@@ -433,3 +433,111 @@ func ChatMiddleware() gin.HandlerFunc {
c.Next()
}
}
type ResponsesWriter struct {
BaseWriter
converter *openai.ResponsesStreamConverter
model string
stream bool
responseID string
itemID string
}
func (w *ResponsesWriter) writeEvent(eventType string, data any) error {
d, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
if err != nil {
return err
}
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
return nil
}
func (w *ResponsesWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
if err := json.Unmarshal(data, &chatResponse); err != nil {
return 0, err
}
if w.stream {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
}
}
return len(data), nil
}
// Non-streaming response
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := openai.ToResponse(w.model, w.responseID, w.itemID, chatResponse)
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
func (w *ResponsesWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
chatReq, err := openai.FromResponsesRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
// Check if client requested streaming (defaults to false)
streamRequested := req.Stream != nil && *req.Stream
// Pass streaming preference to the underlying chat request
chatReq.Stream = &streamRequested
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
responseID := fmt.Sprintf("resp_%d", rand.Intn(999999))
itemID := fmt.Sprintf("msg_%d", rand.Intn(999999))
w := &ResponsesWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
converter: openai.NewResponsesStreamConverter(responseID, itemID, req.Model),
model: req.Model,
stream: streamRequested,
responseID: responseID,
itemID: itemID,
}
// Set headers based on streaming mode
if streamRequested {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
}
c.Writer = w
c.Next()
}
}

View File

@@ -74,7 +74,7 @@ type BackendParams struct {
GPULayers GPULayersList
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
FlashAttention FlashAttentionType
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))

View File

@@ -109,7 +109,7 @@ type Backend struct {
// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
btDeviceMemory map[C.ggml_backend_buffer_type_t]*ml.DeviceMemory
flashAttention bool
flashAttention ml.FlashAttentionType
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
maxGraphNodes int
@@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
}
func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention {
if b.flashAttention == ml.FlashAttentionEnabled {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 256, PermutedV: true}
@@ -1676,7 +1676,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
if t.b.flashAttention {
if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)

View File

@@ -492,6 +492,32 @@ func FlashAttentionSupported(l []DeviceInfo) bool {
return true
}
type FlashAttentionType int32
const (
// Aligned with llama_flash_attn_type
FlashAttentionAuto FlashAttentionType = -1
FlashAttentionDisabled FlashAttentionType = 0
FlashAttentionEnabled FlashAttentionType = 1
)
func (f FlashAttentionType) LogValue() slog.Value {
return slog.AnyValue(f.String())
}
func (f FlashAttentionType) String() string {
switch f {
case FlashAttentionAuto:
return "Auto"
case FlashAttentionDisabled:
return "Disabled"
case FlashAttentionEnabled:
return "Enabled"
default:
return "unknown"
}
}
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
// Set mustFilter true to enable filtering of CUDA devices

View File

@@ -2,7 +2,6 @@ package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -13,25 +12,26 @@ import (
)
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase float32
largeModelScaling bool
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
hiddenSize, contextLength, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase float32
largeModelScaling bool
slidingWindow uint32
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
}
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base, scale float32) ml.Tensor {
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(scale))))
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOriginalContext),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
@@ -41,7 +41,7 @@ func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positi
)
}
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./scale, ropeOpts...)
}
type TextModel struct {
@@ -55,6 +55,9 @@ type TextModel struct {
const (
gemmaGlobalCacheCount = 6
gemma1BLayerCount = 26
gemma4BLayerCount = 34
gemma12BLayerCount = 48
gemma27BLayerCount = 62
)
@@ -70,6 +73,7 @@ func newTextModel(c fs.Config) *TextModel {
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
contextLength: int(c.Uint("context_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
@@ -77,6 +81,7 @@ func newTextModel(c fs.Config) *TextModel {
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeBase: c.Float("rope.freq_base", 1000000.0),
slidingWindow: c.Uint("attention.sliding_window"),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
ropeType: c.String("rope.scaling.type"),
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
@@ -88,14 +93,20 @@ func newTextModel(c fs.Config) *TextModel {
},
}
// Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, and so force it to 0.0
// TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of
// models may include both sliding window attention and final
// logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
m.TextConfig.finalLogitSoftcap = 0.0
// Apply corrections for older versions of the Gemma 3 models
// by looking at whether they use sliding window attention and
// based on their layer counts.
if m.TextConfig.slidingWindow < uint32(m.TextConfig.contextLength) {
switch numBlocks {
case gemma1BLayerCount:
// The 1B model has final logit softcapping set to 30.0
// but it should be 0.0
m.TextConfig.finalLogitSoftcap = 0.0
case gemma4BLayerCount, gemma12BLayerCount, gemma27BLayerCount:
// The 4B, 12B, and 27B models have rope scale unset
// but it shuold be set to 8.0
m.TextConfig.ropeScale = 8.0
}
}
if numBlocks == gemma27BLayerCount {
@@ -114,31 +125,31 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
func (opts *TextConfig) ropeValuesForLayer(layer int) (base float32, scale float32) {
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
return opts.ropeLocalBase
return opts.ropeLocalBase, 1.0
}
// Standard Gemma3: only every n-th layer is global,
// where n = gemmaGlobalCacheCount, otherwise use
// the local rope base
if (layer+1)%gemmaGlobalCacheCount > 0 {
return opts.ropeLocalBase
return opts.ropeLocalBase, 1.0
}
// default to global rope base
return opts.ropeBase
return opts.ropeBase, opts.ropeScale
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeBase := opts.ropeBaseForLayer(layer)
ropeBase, ropeScale := opts.ropeValuesForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase, ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -149,7 +160,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase, ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -162,7 +173,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
ropeBase, ropeScale := m.TextConfig.ropeValuesForLayer(layer)
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase, ropeScale), nil
}
type TextMLP struct {

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -17,10 +18,30 @@ type TextOptions struct {
eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int
ropeScalingBeta float32
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
ropeMscale float32
ropeMscaleAllDim float32
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
var ropeOpts []func(*rope.Options)
if o.ropeType == "yarn" {
if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0)))
}
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
}
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@@ -150,9 +171,15 @@ func newTextModel(c fs.Config) *TextModel {
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
ropeScale: c.Float("rope.scaling.factor", 1.0),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"),
ropeScalingBeta: c.Float("rope.scaling_beta", 0.1),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 32.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeType: c.String("rope.scaling.type"),
ropeMscale: c.Float("rope.scaling.mscale"),
ropeMscaleAllDim: c.Float("rope.scaling.mscale_all_dim"),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1),
},
}
}

View File

@@ -487,29 +487,9 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
types := []string{"jpeg", "jpg", "png", "webp"}
valid := false
// support blank mime type to match api/chat taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
valid = true
}
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, errors.New("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
img, err := decodeImageURL(url)
if err != nil {
return nil, errors.New("invalid message format")
return nil, err
}
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
@@ -648,6 +628,35 @@ func nameFromToolCallID(messages []Message, toolCallID string) string {
return ""
}
// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
types := []string{"jpeg", "jpg", "png", "webp"}
// Support blank mime type to match /api/chat's behavior of taking just unadorned base64
if strings.HasPrefix(url, "data:;base64,") {
url = strings.TrimPrefix(url, "data:;base64,")
} else {
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, errors.New("invalid image input")
}
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, errors.New("invalid image input")
}
return img, nil
}
// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
apiToolCalls := make([]api.ToolCall, len(toolCalls))

1015
openai/responses.go Normal file

File diff suppressed because it is too large Load Diff

1842
openai/responses_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,7 @@ import (
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/runner/common"
)
@@ -832,7 +833,7 @@ func (s *Server) loadModel(
ppath string,
kvSize int,
kvCacheType string,
flashAttention bool,
flashAttention ml.FlashAttentionType,
threads int,
multiUserCache bool,
) {
@@ -842,7 +843,7 @@ func (s *Server) loadModel(
panic(err)
}
ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
ctxParams := llama.NewContextParams(kvSize, s.batchSize, s.parallel, threads, flashAttention, kvCacheType)
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
if err != nil {
panic(err)

View File

@@ -1203,16 +1203,22 @@ func (s *Server) allocModel(
return errors.New("loras are not yet implemented")
}
if s.model.Config().Cache == nil {
if parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
if s.batchSize < kvSize {
s.batchSize = kvSize
slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize)
}
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
return err
}
if !s.cache.enabled && parallel > 1 {
parallel = 1
slog.Warn("model does not support caching, disabling parallel processing")
}
s.parallel = parallel
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

View File

@@ -1532,6 +1532,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
if rc != nil {
// wrap old with new
@@ -2195,7 +2196,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
@@ -2235,8 +2236,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
// don't return, fall through to send
} else {
// Send logprobs while content is being buffered by the parser for tool calls
if len(res.Logprobs) > 0 && !r.Done {
logprobRes := res
logprobRes.Message.Content = ""
logprobRes.Message.ToolCalls = nil
ch <- logprobRes
}
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
@@ -2385,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
}
return msgs
}

View File

@@ -708,6 +708,95 @@ func TestGenerateChat(t *testing.T) {
}
})
t.Run("messages with tools and logprobs (streaming)", func(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
}
var wg sync.WaitGroup
wg.Add(1)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
// Simulate a response where logprobs are sent while the tool call is being buffered
responses := []llm.CompletionResponse{
{
Content: `{ "name": "get_weather"`,
Done: false,
Logprobs: []llm.Logprob{{}},
},
{
Content: `,"arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
Done: false,
Logprobs: []llm.Logprob{{}},
},
{
Content: ``,
Done: true,
DoneReason: llm.DoneReasonStop,
Logprobs: nil,
},
}
for _, resp := range responses {
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
time.Sleep(10 * time.Millisecond)
}
}
return nil
}
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Weather?"},
},
Tools: tools,
Stream: &stream,
})
wg.Wait()
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
decoder := json.NewDecoder(w.Body)
var totalLogprobs int
for {
var resp api.ChatResponse
if err := decoder.Decode(&resp); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
totalLogprobs += len(resp.Logprobs)
}
expectedLogprobs := 2
if totalLogprobs != expectedLogprobs {
t.Errorf("expected %d logprobs, got %d", expectedLogprobs, totalLogprobs)
}
})
t.Run("status error non-streaming", func(t *testing.T) {
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
return api.StatusError{

View File

@@ -127,6 +127,9 @@ var funcs = template.FuncMap{
// Default format is YYYY-MM-DD
return time.Now().Format("2006-01-02")
},
"yesterdayDate": func(args ...string) string {
return time.Now().AddDate(0, 0, -1).Format("2006-01-02")
},
"toTypeScriptType": func(v any) string {
if param, ok := v.(api.ToolProperty); ok {
return param.ToTypeScriptType()

View File

@@ -10,6 +10,7 @@ import (
"slices"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
@@ -451,6 +452,72 @@ func TestExecuteWithSuffix(t *testing.T) {
}
}
func TestDateFunctions(t *testing.T) {
t.Run("currentDate", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Today is {{ currentDate }}")
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}
expected := "Hello Today is " + time.Now().Format("2006-01-02")
if b.String() != expected {
t.Errorf("got %q, want %q", b.String(), expected)
}
})
t.Run("yesterdayDate", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ .Content }}{{ end }} Yesterday was {{ yesterdayDate }}")
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}
expected := "Hello Yesterday was " + time.Now().AddDate(0, 0, -1).Format("2006-01-02")
if b.String() != expected {
t.Errorf("got %q, want %q", b.String(), expected)
}
})
t.Run("yesterdayDate format", func(t *testing.T) {
tmpl, err := Parse("{{- range .Messages }}{{ end }}{{ yesterdayDate }}")
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Messages: []api.Message{{Role: "user", Content: "Hello"}}}); err != nil {
t.Fatal(err)
}
// Verify the format matches YYYY-MM-DD
result := b.String()
if len(result) != 10 {
t.Errorf("expected date length 10, got %d: %q", len(result), result)
}
// Parse and verify it's a valid date
parsed, err := time.Parse("2006-01-02", result)
if err != nil {
t.Errorf("failed to parse date %q: %v", result, err)
}
// Verify it's yesterday
yesterday := time.Now().AddDate(0, 0, -1)
if parsed.Year() != yesterday.Year() || parsed.Month() != yesterday.Month() || parsed.Day() != yesterday.Day() {
t.Errorf("expected yesterday's date, got %v", parsed)
}
})
}
func TestCollate(t *testing.T) {
cases := []struct {
name string