Compare commits

1 commit: parth/move ... jmorganca/
Commit: 19638cec55
@@ -411,8 +411,6 @@ See the [API documentation](./docs/api.md) for all endpoints.

- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)

### Cloud

@@ -539,8 +537,6 @@ See the [API documentation](./docs/api.md) for all endpoints.

- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))

### Mobile
api/types.go (21 changes)

@@ -90,10 +90,6 @@ type GenerateRequest struct {
 	// (request that thinking _not_ be used) and unset (use the old behavior
 	// before this option was introduced)
 	Think *ThinkValue `json:"think,omitempty"`
-
-	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
-	// template instead of calling the model.
-	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
 }

 // ChatRequest describes a request sent by [Client.Chat].
@@ -124,10 +120,6 @@ type ChatRequest struct {
 	// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
 	// for supported models.
 	Think *ThinkValue `json:"think,omitempty"`
-
-	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
-	// template instead of calling the model.
-	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
 }

 type Tools []Tool

@@ -316,19 +308,6 @@ type ChatResponse struct {
 	Metrics
 }
-
-// DebugInfo contains debug information for template rendering
-type DebugInfo struct {
-	RenderedTemplate string `json:"rendered_template"`
-	ImageCount       int    `json:"image_count,omitempty"`
-}
-
-// DebugTemplateResponse is returned when _debug_render_only is set to true
-type DebugTemplateResponse struct {
-	Model     string    `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	DebugInfo DebugInfo `json:"_debug_info"`
-}

 type Metrics struct {
 	TotalDuration time.Duration `json:"total_duration,omitempty"`
 	LoadDuration  time.Duration `json:"load_duration,omitempty"`
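The surviving `Think` field accepts either a boolean or a string ("high", "medium", "low") for supported models. A minimal caller-side sketch, assuming a `Value` field on `api.ThinkValue` (that field name is illustrative; only the `Think *ThinkValue` shape is confirmed by this diff):

```go
// Hypothetical usage sketch; ThinkValue's internal layout is assumed.
req := &api.ChatRequest{
	Model:    "example-model", // placeholder model name
	Messages: []api.Message{{Role: "user", Content: "Why is the sky blue?"}},
	Think:    &api.ThinkValue{Value: "high"}, // or true/false for boolean-only models
}
```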
docs/docs.json (new file, +75)

@@ -0,0 +1,75 @@
{
  "$schema": "https://mintlify.com/docs.json",
  "theme": "mint",
  "background": {
    "color": {
      "light": "#ffffff",
      "dark": "#000000"
    }
  },
  "appearance": {
    "default": "light"
  },
  "styling": {
    "codeblocks": "system"
  },
  "contextual": {
    "options": ["copy", "chatgpt", "claude", "view"]
  },
  "fonts": {
    "heading": {
      "family": "Inter"
    },
    "body": {
      "family": "Inter"
    }
  },
  "name": "Ollama",
  "colors": {
    "primary": "#000",
    "light": "#b5b5b5",
    "dark": "#fff"
  },
  "favicon": "/ollama.png",
  "logo": {
    "light": "/ollama.png",
    "dark": "/favicon.svg"
  },
  "navigation": {
    "tabs": [
      {
        "tab": "Documentation",
        "groups": [
          {
            "group": "Home",
            "pages": ["index", "quickstart", "faq", "troubleshooting"]
          },
          {
            "group": "Platforms",
            "pages": ["linux", "windows", "docker"]
          },
          {
            "group": "Features",
            "pages": [
              "modelfile",
              "apis",
              "openai",
              "import",
              "gpu",
              "benchmark"
            ]
          }
        ]
      },
      {
        "tab": "Development",
        "groups": [
          {
            "group": " ",
            "pages": ["development", "examples", "template"]
          }
        ]
      }
    ]
  }
}
@@ -2,13 +2,10 @@

 This directory contains integration tests to exercise Ollama end-to-end to verify behavior

-By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
+By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`

 The integration tests have 2 modes of operating.

 1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
-2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
-
-> [!IMPORTANT]
-> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
+2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
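The removed paragraph referred to scoped build tags. For reference, integration test files are gated with Go build constraints along these lines (a generic sketch, not a verbatim file from the tree):

```go
//go:build integration && models

package integration

// A file carrying this constraint only compiles into the test binary when
// both tags are supplied, e.g.: go test -tags=integration,models ./integration/...
```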
@@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) {
 	DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
 }

-// Send multiple generate requests with prior context and ensure the response is coherant and expected
+// Send multiple requests with prior context and ensure the response is coherent and expected
 func TestGenerateWithHistory(t *testing.T) {
 	modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
 	req, resp := GenerateRequests()
@@ -111,56 +111,5 @@ func TestGenerateWithHistory(t *testing.T) {
 		}(i)
 	}
 	wg.Wait()
 }

-// Send multiple chat requests with prior context and ensure the response is coherant and expected
-func TestChatWithHistory(t *testing.T) {
-	modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
-	req, resp := ChatRequests()
-	numParallel := 2
-	iterLimit := 2
-
-	softTimeout, hardTimeout := getTimeouts(t)
-	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
-	defer cancel()
-	client, _, cleanup := InitServerConnection(ctx, t)
-	defer cleanup()
-
-	// Get the server running (if applicable) warm the model up with a single initial empty request
-	slog.Info("loading", "model", modelOverride)
-	err := client.Generate(ctx,
-		&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
-		func(response api.GenerateResponse) error { return nil },
-	)
-	if err != nil {
-		t.Fatalf("failed to load model %s: %s", modelOverride, err)
-	}
-
-	var wg sync.WaitGroup
-	wg.Add(numParallel)
-	for i := range numParallel {
-		go func(i int) {
-			defer wg.Done()
-			k := i % len(req)
-			req[k].Model = modelOverride
-			for j := 0; j < iterLimit; j++ {
-				if time.Now().Sub(started) > softTimeout {
-					slog.Info("exceeded soft timeout, winding down test")
-					return
-				}
-				slog.Info("Starting", "thread", i, "iter", j)
-				// On slower GPUs it can take a while to process the concurrent requests
-				// so we allow a much longer initial timeout
-				assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
-				if assistant == nil {
-					t.Fatalf("didn't get an assistant response for context")
-				}
-				req[k].Messages = append(req[k].Messages,
-					*assistant,
-					api.Message{Role: "user", Content: "tell me more!"},
-				)
-			}
-		}(i)
-	}
-	wg.Wait()
-
-}
@@ -19,8 +19,6 @@ import (
 )

 func TestMaxQueue(t *testing.T) {
-	t.Skip("this test needs to be re-evaluated to use a proper embedding model")
-
 	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
 		t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
 		return
@@ -567,76 +567,6 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		}
 	}
 }

-func ChatRequests() ([]api.ChatRequest, [][]string) {
-	genReqs, results := GenerateRequests()
-	reqs := make([]api.ChatRequest, len(genReqs))
-	for i := range reqs {
-		reqs[i].Model = genReqs[i].Model
-		reqs[i].Stream = genReqs[i].Stream
-		reqs[i].KeepAlive = genReqs[i].KeepAlive
-		reqs[i].Messages = []api.Message{
-			{
-				Role:    "user",
-				Content: genReqs[i].Prompt,
-			},
-		}
-	}
-	return reqs, results
-}
-
-func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
-	stallTimer := time.NewTimer(initialTimeout)
-	var buf bytes.Buffer
-	role := "assistant"
-	fn := func(response api.ChatResponse) error {
-		// fmt.Print(".")
-		role = response.Message.Role
-		buf.Write([]byte(response.Message.Content))
-		if !stallTimer.Reset(streamTimeout) {
-			return errors.New("stall was detected while streaming response, aborting")
-		}
-		return nil
-	}
-
-	stream := true
-	req.Stream = &stream
-	done := make(chan int)
-	var genErr error
-	go func() {
-		genErr = client.Chat(ctx, &req, fn)
-		done <- 0
-	}()
-
-	select {
-	case <-stallTimer.C:
-		if buf.Len() == 0 {
-			t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
-		} else {
-			t.Errorf("generate stalled. Response so far:%s", buf.String())
-		}
-	case <-done:
-		if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
-			slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
-			return nil
-		}
-		require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages)
-		// Verify the response contains the expected data
-		response := buf.String()
-		atLeastOne := false
-		for _, resp := range anyResp {
-			if strings.Contains(strings.ToLower(response), resp) {
-				atLeastOne = true
-				break
-			}
-		}
-		require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
-		slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
-	case <-ctx.Done():
-		t.Error("outer test context done while waiting for generate")
-	}
-	return &api.Message{Role: role, Content: buf.String()}
-}
-
 func skipUnderMinVRAM(t *testing.T, gb uint64) {
 	// TODO use info API in the future
 	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
llama/llama.cpp/src/llama-context.cpp (vendored, 3 changes)

@@ -962,7 +962,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const int64_t n_vocab = vocab.n_tokens();
     const int64_t n_embd = hparams.n_embd;

-    const bool output_all = false;
+    // when computing embeddings, all tokens are output
+    const bool output_all = cparams.embeddings;

     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
@@ -13,7 +13,7 @@ checks.
  1 file changed, 18 insertions(+)

 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 57eae461..c7f9dc3a 100644
+index 57eae461..9db0c8b5 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
 @@ -2671,12 +2671,24 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
@@ -1,23 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: Michael Yang <git@mxy.ng>
-Date: Mon, 18 Aug 2025 16:58:39 -0700
-Subject: [PATCH] decode: disable output_all
-
----
- src/llama-context.cpp | 3 +--
- 1 file changed, 1 insertion(+), 2 deletions(-)
-
-diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 26a5cf9c..6ece5263 100644
---- a/src/llama-context.cpp
-+++ b/src/llama-context.cpp
-@@ -962,8 +962,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
-     const int64_t n_vocab = vocab.n_tokens();
-     const int64_t n_embd = hparams.n_embd;
-
--    // when computing embeddings, all tokens are output
--    const bool output_all = cparams.embeddings;
-+    const bool output_all = false;
-
-     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
-         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
@@ -31,7 +31,6 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
-	"github.com/ollama/ollama/harmony"
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
@@ -652,9 +651,7 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
 		if !success {
 			s.initModel(ctx, LoadRequest{}, LoadOperationClose)
 		}
-		if s.mem != nil {
-			s.mem.Log(slog.LevelInfo)
-		}
+		s.mem.Log(slog.LevelInfo)
 	}()

 	slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
@@ -1332,9 +1329,7 @@ type CompletionRequest struct {
 	Images  []ImageData
 	Options *api.Options

-	Grammar         string // set before sending the request to the subprocess
-	FunctionNameMap *harmony.FunctionNameMap
-	PrefillContent  *bool
+	Grammar string // set before sending the request to the subprocess
 }

 // DoneReason represents the reason why a completion response is done
@@ -1361,15 +1356,13 @@ func (d DoneReason) String() string {
 }

 type CompletionResponse struct {
-	Content            string         `json:"content"`
-	Thinking           string         `json:"thinking"`
-	ToolCalls          []api.ToolCall `json:"tool_calls"`
-	DoneReason         DoneReason     `json:"done_reason"`
-	Done               bool           `json:"done"`
-	PromptEvalCount    int            `json:"prompt_eval_count"`
-	PromptEvalDuration time.Duration  `json:"prompt_eval_duration"`
-	EvalCount          int            `json:"eval_count"`
-	EvalDuration       time.Duration  `json:"eval_duration"`
+	Content            string        `json:"content"`
+	DoneReason         DoneReason    `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }

 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1487,7 +1480,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 		}
 		switch {
-		case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
+		case strings.TrimSpace(c.Content) == lastToken:
 			tokenRepeat++
 		default:
 			lastToken = strings.TrimSpace(c.Content)
@@ -1500,14 +1493,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return ctx.Err()
 		}

-		if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
-			fn(c)
-		}
+		if c.Content != "" {
+			fn(CompletionResponse{
+				Content: c.Content,
+			})
+		}
+
+		if c.Done {
+			fn(c)
+			return nil
+		}
 	}
 }
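With thinking and tool calls dropped from CompletionResponse, the streaming loop now forwards plain content chunks and then one final done response. A sketch of a consumer under the narrowed struct (fragment only; the surrounding request plumbing is assumed):

```go
// Intermediate callbacks carry only Content; the final callback has Done set
// along with the evaluation metrics.
if err := s.Completion(ctx, req, func(c CompletionResponse) {
	fmt.Print(c.Content)
	if c.Done {
		fmt.Printf("\n%d tokens in %v\n", c.EvalCount, c.EvalDuration)
	}
}); err != nil {
	log.Fatal(err)
}
```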
@@ -400,8 +400,6 @@ type Tensor interface {
 	Bytes() []byte
 	Floats() []float32

-	BackendSetFromIntSlice(s []int32)
-
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
@@ -82,7 +82,6 @@ type Backend struct {
 	// to the name that is used by the model definition
 	tensorLoadTargets map[string][]string

-	schedMu       sync.Mutex // Only one Compute can run at a time
 	sched         C.ggml_backend_sched_t
 	schedBackends []C.ggml_backend_t
 	schedBufts    []C.ggml_backend_buffer_type_t
@@ -770,8 +769,6 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 }

 func (c *Context) Compute(tensors ...ml.Tensor) {
-	c.b.schedMu.Lock()
-	defer c.b.schedMu.Unlock()
 	if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
 		panic(fmt.Errorf("error computing ggml graph: %v", status))
 	}
@@ -1040,12 +1037,6 @@ func (t *Tensor) Floats() (data []float32) {
 	return
 }

-func (t *Tensor) BackendSetFromIntSlice(s []int32) {
-	if len(s) > 0 {
-		C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
-	}
-}
-
 func (t *Tensor) DType() ml.DType {
 	switch t.t._type {
 	case C.GGML_TYPE_F32:
@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize([]*input.Input) ([]*input.Input, error)
+	PostTokenize([]input.Input) ([]input.Input, error)
 }

 // Base implements the common fields and methods for all models
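Inputs now travel by value rather than by pointer through this interface. The simplest conforming implementation after the change is a passthrough; the model files later in this diff show the real multimodal variants:

```go
// Minimal PostTokenize under the new value-slice signature: a model that
// injects no image tokens can return its inputs unchanged.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
	return inputs, nil
}
```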
@@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }

-func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
 	if len(batch.Positions) != len(batch.Sequences) {
-		return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}

 	if len(batch.Positions) < 1 {
-		return nil, nil, errors.New("batch size cannot be less than 1")
+		return nil, errors.New("batch size cannot be less than 1")
 	}

 	batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
@@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
 	if cache != nil {
 		err := cache.StartForward(ctx, batch, false)
 		if err != nil {
-			return nil, nil, err
+			return nil, err
 		}
 	}

 	t, err := m.Forward(ctx, batch)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}

-	ctx.Forward(t)
+	ctx.Forward(t).Compute(t)

-	return batch.Inputs, t, nil
+	return t, nil
 }
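Forward now computes the graph itself and returns only the output tensor, so callers no longer get the input tensor back or invoke Compute separately. The new call pattern, as it appears in the runner change later in this diff:

```go
// Forward builds and computes the graph in one step; the runner then reads
// logits straight from the returned tensor.
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
	return fmt.Errorf("failed to decode batch: %w", err)
}
logits := modelOutput.Floats()
```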
@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	return []input.Multimodal{{Tensor: visionOutputs}}, nil
 }

-func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
-	var result []*input.Input
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+	var result []input.Input

 	for _, inp := range inputs {
 		if len(inp.Multimodal) == 0 {
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 		inputMultimodal := inp.Multimodal[0].Tensor

 		result = append(result,
-			&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
-			&input.Input{Token: 255999}, // "<start_of_image>""
-			&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+			input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
+			input.Input{Token: 255999}, // "<start_of_image>"
+			input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
 		)

 		// add image token placeholders
-		result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
+		result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)

 		result = append(result,
-			&input.Input{Token: 256000}, // <end_of_image>
-			&input.Input{Token: 108},    // "\n\n"
+			input.Input{Token: 256000}, // <end_of_image>
+			input.Input{Token: 108},    // "\n\n"
 		)
 	}
 }
@@ -134,16 +134,16 @@ type separator struct {
 	y bool
 }

-func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
-	var result []*input.Input
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+	var result []input.Input
 	for _, inp := range inputs {
 		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 			continue
 		}

-		var imageInputs []*input.Input
-		imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
+		var imageInputs []input.Input
+		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>

 		for i, mm := range inp.Multimodal {
 			patchesPerChunk := mm.Tensor.Dim(1)
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 			if i < len(inp.Multimodal)-1 {
 				separator := mm.Data.(*separator)

-				imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-				imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)

 				if separator.x {
-					imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
+					imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
 				}
 				if separator.y {
-					imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
+					imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
 				}
 			} else {
-				imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
-				imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-				imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-				imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
+				imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
+				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+				imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
 			}
 		}
@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
 // Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
 // that can be processed together.
-func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
-	var result []*input.Input
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+	var result []input.Input
 	for _, inp := range inputs {
 		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 		} else {
 			for i, row := range inp.Multimodal {
 				// [IMG]
-				result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
-				result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
+				result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
+				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
 				if i == len(inp.Multimodal)-1 {
 					// [IMG_END]
-					result = append(result, &input.Input{Token: 13})
+					result = append(result, input.Input{Token: 13})
 				} else {
 					// [IMG_BREAK]
-					result = append(result, &input.Input{Token: 12})
+					result = append(result, input.Input{Token: 12})
 				}
 			}
 		}
@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	return []input.Multimodal{{Tensor: projectedOutputs}}, nil
 }

-func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	for i := range inputs {
 		if inputs[i].Multimodal != nil {
 			inputs[i].Token = 128256 // <|image|>
@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 }

 // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
-func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
-	var result []*input.Input
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+	var result []input.Input

 	var (
 		imageToken int32 = 151655
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 				return nil, fmt.Errorf("failed to encode image prompt: %w", err)
 			}
 			for i := range pre {
-				result = append(result, &input.Input{Token: pre[i]})
+				result = append(result, input.Input{Token: pre[i]})
 			}

 			patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)

 			// First add the vision start token
-			result = append(result, &input.Input{Token: visionStartToken})
+			result = append(result, input.Input{Token: visionStartToken})

 			// Add the image token with the multimodal tensor data at the first position
-			result = append(result, &input.Input{
+			result = append(result, input.Input{
 				Token:          imageToken,
 				Multimodal:     inp.Multimodal,
 				MultimodalHash: inp.MultimodalHash,
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 			})

 			// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
-			result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
+			result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)

-			result = append(result, &input.Input{Token: visionEndToken})
+			result = append(result, input.Input{Token: visionEndToken})
 		}
 	}
@@ -86,7 +86,7 @@ type InputCacheSlot struct {
 	Id int

 	// Inputs that are stored in the KV cache
-	Inputs []*input.Input
+	Inputs []input.Input

 	// is this cache actively being processed as part of a sequence?
 	InUse bool
@@ -95,7 +95,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }

-func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*i
 	return slot, prompt, nil
 }

-func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	longest := int32(-1)
 	var longestSlot *InputCacheSlot

@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlo
 	return longestSlot, longest, nil
 }

-func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	oldest := time.Now()
 	var oldestSlot *InputCacheSlot

@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot,
 	if longest > 0 && longestSlot != oldestSlot {
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
 			len(longestSlot.Inputs))
-		oldestSlot.Inputs = make([]*input.Input, longest)
+		oldestSlot.Inputs = make([]input.Input, longest)
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		if c.cache != nil {
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot,
 	return oldestSlot, longest, nil
 }

-func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
+func countCommonPrefix(a []input.Input, b []input.Input) int32 {
 	var count int32

 	for i := range a {
@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
 }

 type ErrReprocessInputs struct {
-	Inputs []*input.Input
+	Inputs []input.Input
 }

 func (e *ErrReprocessInputs) Error() string {
@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
 			"id", slot.Id, "error", err)

 		// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
-		newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
+		newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
 		copy(newInputs[:numKeep], slot.Inputs[:numKeep])
 		copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])

 		// Reset the cache
 		_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
-		slot.Inputs = []*input.Input{}
+		slot.Inputs = []input.Input{}

 		// Return error with inputs that need to be reprocessed
 		return &ErrReprocessInputs{Inputs: newInputs}
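The preserved-token arithmetic in ShiftCacheSlot is easier to see with concrete numbers (the values below are illustrative):

```go
// With inputLen = 10, numKeep = 2, discard = 4:
//   len(newInputs) = 2 + 10 - (2+4) = 6
// newInputs keeps slot.Inputs[0:2] followed by slot.Inputs[6:10],
// dropping the 4 tokens that were shifted out of the cache.
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
```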
@@ -13,50 +13,50 @@ import (
 func TestCountCommon(t *testing.T) {
 	tests := []struct {
 		name     string
-		t1       []*input.Input
-		t2       []*input.Input
+		t1       []input.Input
+		t2       []input.Input
 		expected int32
 	}{
 		{
 			name:     "Equal",
-			t1:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-			t2:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 3,
 		},
 		{
 			name:     "Prefix",
-			t1:       []*input.Input{{Token: 1}},
-			t2:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Image Prefix",
-			t1:       []*input.Input{{MultimodalHash: 1}},
-			t2:       []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
+			t1:       []input.Input{{MultimodalHash: 1}},
+			t2:       []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Mixed",
-			t1:       []*input.Input{{Token: 1}, {MultimodalHash: 1}},
-			t2:       []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
+			t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name:     "Mixed, Same Length",
-			t1:       []*input.Input{{Token: 1}, {MultimodalHash: 1}},
-			t2:       []*input.Input{{Token: 1}, {MultimodalHash: 2}},
+			t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
 			name:     "Empty",
-			t1:       []*input.Input{},
-			t2:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 0,
 		},
 		{
 			name:     "Both Empty",
-			t1:       []*input.Input{},
-			t2:       []*input.Input{},
+			t1:       []input.Input{},
+			t2:       []input.Input{},
 			expected: 0,
 		},
 	}
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
 	tests := []struct {
 		name    string
 		cache   InputCache
-		prompt  []*input.Input
+		prompt  []input.Input
 		longest expected
 		best    expected
 	}{
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []*input.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 				{
 					Id:       1,
-					Inputs:   []*input.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []*input.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 		},
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []*input.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []*input.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []*input.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []*input.Input{{Token: 2}},
+			prompt:  []input.Input{{Token: 2}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []*input.Input{},
+						Inputs:   []input.Input{},
 						InUse:    false,
 						lastUsed: time.Time{},
 					},
 				},
 			},
-			prompt:  []*input.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 1},
 			best:    expected{result: 1, len: 1},
 		},
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []*input.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []*input.Input{{Token: 2}, {Token: 3}},
+			prompt:  []input.Input{{Token: 2}, {Token: 3}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    true,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []*input.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []*input.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 2},
 		},
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
 	tests := []struct {
 		name           string
 		cache          InputCache
-		prompt         []*input.Input
+		prompt         []input.Input
 		wantErr        bool
 		expectedSlotId int
 		expectedPrompt int // expected length of remaining prompt
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []*input.Input{},
+						Inputs:   []input.Input{},
 						InUse:    false,
 						lastUsed: time.Now().Add(-2 * time.Second),
 					},
 				},
 			},
-			prompt:         []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			wantErr:        false,
 			expectedSlotId: 0,
 			expectedPrompt: 1, // Only token 3 remains
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []*input.Input{},
+						Inputs:   []input.Input{},
 						InUse:    false,
 						lastUsed: time.Now().Add(-2 * time.Second),
 					},
 				},
 			},
-			prompt:         []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			wantErr:        false,
 			expectedSlotId: 0,
 			expectedPrompt: 1, // Only token 3 remains
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 				},
 			},
-			prompt:         []*input.Input{{Token: 1}, {Token: 2}},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
 			wantErr:        false,
 			expectedSlotId: 0,
 			expectedPrompt: 1, // Should leave 1 token for sampling
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    true,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 				},
 			},
-			prompt:         []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			wantErr:        true,
 			expectedSlotId: -1,
 			expectedPrompt: -1,
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
 	tests := []struct {
 		name     string
 		numCtx   int32
-		inputs   []*input.Input
+		inputs   []input.Input
 		numKeep  int32
 		cacheErr bool
 		wantErr  any
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
 		{
 			name:     "Normal shift",
 			numCtx:   10,
-			inputs:   []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+			inputs:   []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
 			numKeep:  2,
 			cacheErr: false, // No error
 			wantErr:  nil,
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
 		{
 			name:     "Cache removal fails",
 			numCtx:   10,
-			inputs:   []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+			inputs:   []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
 			numKeep:  2,
 			cacheErr: true,
 			wantErr:  &ErrReprocessInputs{},
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
 			}
 			slot := &InputCacheSlot{
 				Id:     123,
-				Inputs: make([]*input.Input, len(tt.inputs)),
+				Inputs: make([]input.Input, len(tt.inputs)),
 			}
 			copy(slot.Inputs, tt.inputs)
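The table above pins down countCommonPrefix's contract. A body consistent with these cases might look as follows (the real implementation is only partially visible in this diff, so the comparison rule is an assumption inferred from the tests):

```go
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
	var count int32
	for i := range a {
		if i >= len(b) {
			break
		}
		// Assumed rule: token and multimodal hash must both match.
		if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
			break
		}
		count++
	}
	return count
}
```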
@@ -17,7 +17,6 @@ import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -29,7 +28,6 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -53,10 +51,10 @@ type Sequence struct {
|
||||
iBatch int
|
||||
|
||||
// prompt inputs left to evaluate
|
||||
inputs []*input.Input
|
||||
inputs []input.Input
|
||||
|
||||
// inputs that have been added to a batch but not yet submitted to Forward
|
||||
pendingInputs []*input.Input
|
||||
pendingInputs []input.Input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
@@ -88,12 +86,6 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if the sequence if finished and marked for removal on next pass
|
||||
finished bool
|
||||
|
||||
// True if we have to skip this sequence to shift the cache
|
||||
skipForShift bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -190,8 +182,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
|
||||
var inputs []*input.Input
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
|
||||
var inputs []input.Input
|
||||
var ctxs []ml.Context
|
||||
var mmStore multimodalStore
|
||||
|
||||
@@ -218,7 +210,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
inputs = append(inputs, &input.Input{Token: t})
|
||||
inputs = append(inputs, input.Input{Token: t})
|
||||
}
|
||||
|
||||
// image - decode and store
|
||||
@@ -251,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
|
||||
mmStore.addMultimodal(imageEmbeddings)
|
||||
|
||||
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
postTokenize = true
|
||||
}
|
||||
}
|
||||
@@ -267,27 +259,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
|
||||
return inputs, ctxs, mmStore, nil
|
||||
}
|
||||
|
||||
type batchState struct {
|
||||
id int
|
||||
ctx ml.Context
|
||||
modelInput ml.Tensor
|
||||
modelOutput ml.Tensor
|
||||
batchInputs []*input.Input
|
||||
batch input.Batch
|
||||
seqs []*Sequence // full set of seqs at the time this batch was initiated
|
||||
initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1)
|
||||
|
||||
// Signaled when this batches inputs are ready and compute can proceed
|
||||
inputsReadyCh chan struct{}
|
||||
|
||||
// Signaling when Compute is about to begin on this batch, and
|
||||
// seqs have been updated to prepare for the next batch
|
||||
computeStartedCh chan struct{}
|
||||
|
||||
// Signaled when this batches outputs are complete and the next batch can proceed
|
||||
outputsReadyCh chan struct{}
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// modelPath is the location of the model to be loaded
|
||||
modelPath string
|
||||
@@ -319,16 +290,6 @@ type Server struct {
|
||||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// Used to signal a hard failure during async processing which will panic the runner
|
||||
hardErrCh chan error
|
||||
|
||||
// A prior batch that's still being processed
|
||||
// only read or written by forwardBatch
|
||||
pendingBatch *batchState
|
||||
|
||||
// Simple counter used only for trace logging batches
|
||||
batchID int
|
||||
|
||||
// protects access to everything below this line
|
||||
// this is context state needed for decoding
|
||||
mu sync.Mutex
|
||||
@@ -389,132 +350,45 @@ func flushPending(seq *Sequence) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
// finish could be called multiple times since we prepare 1 batch ahead
|
||||
// and multiple scenarios can lead to finishing a sequence
|
||||
// ensure only the first finish called is processed
|
||||
if seq.finished {
|
||||
return
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
seq.finished = true
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
func (s *Server) removeFinishedSequence(seqIndex int) {
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
// track batch state between forwardBatch, computeBatch and predictForwardBatch
|
||||
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
var bs *batchState
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-s.hardErrCh:
|
||||
panic(err)
|
||||
default:
|
||||
var err error
|
||||
bs, err = s.forwardBatch()
|
||||
err := s.processBatch()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if bs == nil {
|
||||
continue
|
||||
}
|
||||
go s.computeBatch(bs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forwardBatch will calculate a batch.
|
||||
func (s *Server) forwardBatch() (*batchState, error) {
|
||||
inputsReady := false
|
||||
var inputsReadyCh chan struct{}
|
||||
|
||||
// If we have a pending batch still processing, wait until Compute has started
|
||||
// before setting up the next batch so the seqs inputs are ready to receive their
|
||||
// token values and we get the correct input pointers for the batchInputs
|
||||
if s.pendingBatch != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
|
||||
<-s.pendingBatch.computeStartedCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
|
||||
inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||
inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
|
||||
inputsReadyCh = make(chan struct{}, 1)
|
||||
}
|
||||
|
||||
func (s *Server) processBatch() error {
|
||||
s.mu.Lock()
|
||||
for s.allNil() {
|
||||
s.cond.Wait() // Wait until an item is added
|
||||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// If new sequences have been added with an active batch we delay preparing the next batch
|
||||
// until Compute has finished
|
||||
if s.pendingBatch != nil {
|
||||
for seqIdx := range s.seqs {
|
||||
if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
|
||||
s.mu.Unlock() // release the lock so computeBatch can finish up
|
||||
<-s.pendingBatch.outputsReadyCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
|
||||
s.mu.Lock()
|
||||
inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Clear pending Batch - we'll set it if we have a batch with any inputs
|
||||
s.pendingBatch = nil
|
||||
|
||||
// Remove any finished sequences before recording the active set of seqs in the batch
|
||||
for seqIdx := range s.seqs {
|
||||
seq := s.seqs[seqIdx]
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
if seq.finished {
|
||||
s.removeFinishedSequence(seqIdx)
|
||||
continue
|
||||
}
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.finishSequence(seqIdx, llm.DoneReasonLength)
|
||||
s.removeFinishedSequence(seqIdx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// next batch
|
||||
nb := &batchState{
|
||||
id: s.batchID,
|
||||
initSeqIdx: s.nextSeq - 1,
|
||||
seqs: make([]*Sequence, len(s.seqs)),
|
||||
inputsReadyCh: inputsReadyCh,
|
||||
computeStartedCh: make(chan struct{}, 1),
|
||||
outputsReadyCh: make(chan struct{}, 1),
|
||||
}
|
||||
ctx := s.model.Backend().NewContext()
|
||||
nb.ctx = ctx
|
||||
defer ctx.Close()
|
||||
|
||||
// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
|
||||
copy(nb.seqs, s.seqs)
|
||||
|
||||
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||
var batchInputs []*input.Input
|
||||
var batchInputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
resumeSeq := -1
|
||||
@@ -522,13 +396,20 @@ func (s *Server) forwardBatch() (*batchState, error) {
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
}
|
||||
|
||||
batchSize := s.batchSize
|
||||
@@ -568,21 +449,18 @@ func (s *Server) forwardBatch() (*batchState, error) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
seq.skipForShift = true // cleared in computeBatch below for the next batch
|
||||
continue
|
||||
} else {
|
||||
ctx.Close()
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batchInputs = append(batchInputs, seq.inputs[i])
|
||||
batchInputs = append(batchInputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
||||
}
|
||||
@@ -590,13 +468,10 @@ func (s *Server) forwardBatch() (*batchState, error) {
			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			// TODO BUG HERE!!!
			// Somehow sometimes iBatch isn't set correctly
			seq.iBatch = len(batch.Outputs)
			if i+1 == len(seq.inputs) {
				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
			}
			slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
			seq.pendingInputs = append(seq.pendingInputs, inp)
		}

@@ -610,138 +485,36 @@ func (s *Server) forwardBatch() (*batchState, error) {
	}

	if len(batchInputs) == 0 {
		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
		ctx.Close()
		return nil, nil
		return nil
	}
	s.batchID++

	var err error
	// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
	nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
	if err != nil {
		ctx.Close()
		return nil, fmt.Errorf("failed to build graph: %w", err)
	}
	nb.batchInputs = batchInputs
	nb.batch = batch

	// computeBatch will close the context in the batch upon completion
	s.pendingBatch = nb

	if inputsReady {
		nb.inputsReadyCh <- struct{}{}
		return fmt.Errorf("failed to decode batch: %w", err)
	}

	return nb, nil
}
	logits := modelOutput.Floats()

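For orientation, the handoff above can be reduced to plain channels: forwardBatch builds the graph against placeholder token values, and the previous batch's computeBatch fills in the real tokens before signaling inputsReadyCh. The following self-contained sketch mirrors the batchState field names; everything else (the token values, the +1 "sampling") is invented for illustration.

package main

import "fmt"

// batch is a pared-down stand-in for batchState: an ID, a token slice, and
// the two channels the handshake needs.
type batch struct {
	id             int
	tokens         []int32
	inputsReadyCh  chan struct{}
	outputsReadyCh chan struct{}
}

func newBatch(id, n int) *batch {
	return &batch{
		id:             id,
		tokens:         make([]int32, n),
		inputsReadyCh:  make(chan struct{}, 1),
		outputsReadyCh: make(chan struct{}, 1),
	}
}

func main() {
	cur := newBatch(0, 1)
	cur.tokens[0] = 42              // prompt token: known up front
	cur.inputsReadyCh <- struct{}{} // so the first compute never waits

	next := newBatch(1, 1) // next.tokens[0] is a placeholder for now

	go func() {
		<-cur.inputsReadyCh          // wait for real token values
		sampled := cur.tokens[0] + 1 // stand-in for Compute + sampling
		next.tokens[0] = sampled     // fill the next batch's placeholder
		next.inputsReadyCh <- struct{}{}
		cur.outputsReadyCh <- struct{}{}
	}()

	<-next.inputsReadyCh // the next forward pass may now be computed
	<-cur.outputsReadyCh
	fmt.Println("batch 1 input token:", next.tokens[0]) // 43
}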
// Async processing of the next batch
func (s *Server) computeBatch(bs *batchState) {
	if bs == nil || bs.ctx == nil {
		// Nothing to compute
		return
	}
	defer bs.ctx.Close()

	// Wait until inputs are ready
	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
	<-bs.inputsReadyCh
	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)

	// Once we complete, signal the next batch of inputs are ready
	// This will unblock the next computeBatch, or forwardBatch if new seqs come in
	defer func() {
		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
		bs.outputsReadyCh <- struct{}{}
	}()

	s.mu.Lock()

	// Gather the actual input token values now that they're ready
	batchInputs := make([]int32, len(bs.batchInputs))
	for i := range batchInputs {
		batchInputs[i] = bs.batchInputs[i].Token
	}

	// TODO the following logic could be run in a go routine to possibly speed up getting to Compute

	// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
	// so that forwardBatch can build a batchInputs set which will eventually contain the actual
	// decoded tokens.
	promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip
	nextBatchTokens := make([]*input.Input, len(s.seqs))
	iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
	for i, seq := range s.seqs {
		iBatches[i] = -1
		if seq == nil {
			continue
		}
		// Skip over any newly added sequences
		if bs.seqs[i] == nil {
			continue
		}

		// After calling Forward, pending inputs are now in the cache
		if len(seq.pendingInputs) > 0 {
			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
			seq.pendingInputs = []*input.Input{}
			seq.pendingInputs = []input.Input{}
		}

		// don't sample prompt processing
		if len(seq.inputs) != 0 {
			if !s.cache.enabled {
				s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
				return
				return errors.New("caching disabled but unable to fit entire input in a batch")
			}
			// Record so we can skip during Decode
			promptProcessing[i] = true
			continue
		}

		seq.numPredicted++
		nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
		seq.inputs = []*input.Input{nextToken}
		nextBatchTokens[i] = nextToken
		iBatches[i] = seq.iBatch
	}

	// At this point the seqs are ready for forwardBatch to move forward so unblock
	s.mu.Unlock()
	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
	bs.computeStartedCh <- struct{}{}

	bs.modelInput.BackendSetFromIntSlice(batchInputs)
	bs.ctx.Compute(bs.modelOutput)
	logits := bs.modelOutput.Floats()

	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)

	s.mu.Lock()
	defer s.mu.Unlock()

	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
	for i, seq := range s.seqs {
		if seq == nil {
			continue
		}
		// Skip over any newly added sequences
		if bs.seqs[i] == nil {
			continue
		}

		// Detect if the sequence we're processing has already been completed and replaced
		// with a new sequence
		if seq != bs.seqs[i] {
			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
			continue
		}

		// don't sample prompt processing
		if promptProcessing[i] {
			continue
		}

		if seq.numPredicted == 1 {
			seq.startGenerationTime = time.Now()
		}
@@ -749,46 +522,35 @@ func (s *Server) computeBatch(bs *batchState) {
		// if done processing the prompt, generate an embedding and return
		if seq.embeddingOnly {
			// TODO(jessegross): Embedding support
			slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
			s.finishSequence(i, llm.DoneReasonStop)
			slog.Warn("generation of embedding outputs not yet supported")
			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		// sample a token
		vocabSize := len(logits) / len(bs.batch.Outputs)
		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
		token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
			return
		}
		vocabSize := len(logits) / len(batch.Outputs)

		nextBatchTokens[i].Token = token
		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
		if err != nil {
			return fmt.Errorf("failed to sample token: %w", err)
		}

		// if it's an end of sequence token, break
		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
			// TODO (jmorganca): we should send this back
			// as it's important for the /api/generate context
			// seq.responses <- piece
			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
			s.finishSequence(i, llm.DoneReasonStop)

			s.removeSequence(i, llm.DoneReasonStop)
			continue
		}

		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
			return
			return err
		}

		if nextBatchTokens[i] == nil {
			slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
			s.hardErrCh <- fmt.Errorf("expected a single token during decode")
			return
		}

		// fill in the final selected token value to replace the placeholder in the next batch
		// nextBatchTokensWritten++
		seq.inputs = []input.Input{{Token: token}}

		seq.pendingResponses = append(seq.pendingResponses, piece)
		sequence := strings.Join(seq.pendingResponses, "")
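The sampling above carves one row per output sequence out of the flat buffer Floats() returns. A standalone sketch of that slicing, with made-up numbers and a greedy argmax standing in for seq.sampler.Sample:

package main

import "fmt"

func main() {
	// Two sequences requested an output row in this batch, i.e. len(batch.Outputs) == 2.
	logits := []float32{
		0.1, 0.7, 0.1, 0.1, // row for iBatch 0
		0.2, 0.2, 0.5, 0.1, // row for iBatch 1
	}
	numOutputs := 2

	vocabSize := len(logits) / numOutputs
	for iBatch := 0; iBatch < numOutputs; iBatch++ {
		row := logits[iBatch*vocabSize : (iBatch+1)*vocabSize]
		best := 0 // greedy argmax as a stand-in for the real sampler
		for tok, v := range row {
			if v > row[best] {
				best = tok
			}
		}
		fmt.Printf("iBatch %d -> token %d\n", iBatch, best) // token 1, then token 2
	}
}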
@@ -813,10 +575,9 @@ func (s *Server) computeBatch(bs *batchState) {
		if tokenTruncated || origLen == newLen {
			tokenLen--
		}

		seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

		s.finishSequence(i, llm.DoneReasonStop)
		s.removeSequence(i, llm.DoneReasonStop)
		continue
	}

@@ -829,9 +590,11 @@ func (s *Server) computeBatch(bs *batchState) {
		}

		if !flushPending(seq) {
			s.finishSequence(i, llm.DoneReasonConnectionClosed)
			s.removeSequence(i, llm.DoneReasonConnectionClosed)
		}
	}

	return nil
}

func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -841,15 +604,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
		return
	}

	var harmonyMessageHandler *harmony.HarmonyMessageHandler
	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
	if req.FunctionNameMap != nil {
		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
		harmonyMessageHandler.FunctionNameMap = req.FunctionNameMap
		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
	}

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts
@@ -940,16 +694,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
			return
		case content, ok := <-seq.responses:
			if ok {
				var thinking string
				if harmonyMessageHandler != nil {
					var toolContent string
					content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
					harmonyToolParser.Add(toolContent)
				}

				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content:  content,
					Thinking: thinking,
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)
@@ -958,29 +704,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {

				flusher.Flush()
			} else {
				var toolCalls []api.ToolCall
				if harmonyMessageHandler != nil {
					toolName, toolContent := harmonyToolParser.Drain()
					if toolName != nil {
						*toolName = strings.TrimPrefix(*toolName, "functions.")
						*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
						var args api.ToolCallFunctionArguments
						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
							http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
							close(seq.quit)
							return
						}
						toolCalls = append(toolCalls, api.ToolCall{
							Function: api.ToolCallFunction{
								Name:      *toolName,
								Arguments: args,
							},
						})
					}
				}

				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					ToolCalls:       toolCalls,
					Done:            true,
					DoneReason:      seq.doneReason,
					PromptEvalCount: seq.numPromptInputs,
@@ -1012,10 +736,7 @@ func (s *Server) reserveWorstCaseGraph() error {
	defer ctx.Close()

	var err error
	inputs := make([]*input.Input, s.batchSize)
	for i := range inputs {
		inputs[i] = &input.Input{}
	}
	inputs := make([]input.Input, s.batchSize)
	mmStore := newMultimodalStore()

	// Multimodal strategy:
@@ -1057,11 +778,8 @@ func (s *Server) reserveWorstCaseGraph() error {
		}

		if len(inputs) < s.batchSize {
			newInputs := make([]*input.Input, s.batchSize)
			newInputs := make([]input.Input, s.batchSize)
			copy(newInputs, inputs)
			for i := len(inputs); i < s.batchSize; i++ {
				newInputs[i] = &input.Input{}
			}
			inputs = newInputs
		}
	}
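The []*input.Input to []input.Input switch in these hunks is what lets the fill loops go away: a value slice arrives zero-initialized in a single allocation, while a pointer slice starts out all nil. A small illustration, with Input as a stand-in struct rather than the real type:

package main

import "fmt"

// Input is a stand-in for the real input.Input struct.
type Input struct{ Token int32 }

func main() {
	// old shape: every element must be allocated before it can be touched
	ptrs := make([]*Input, 4)
	for i := range ptrs {
		ptrs[i] = &Input{}
	}

	// new shape: zero-initialized in one allocation, usable immediately
	vals := make([]Input, 4)

	fmt.Println(ptrs[0].Token, vals[0].Token) // 0 0, but vals needed no loop
}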
@@ -1124,7 +842,6 @@ func (s *Server) allocModel(
	// Convert memory allocation panics to errors
	defer func() {
		if r := recover(); r != nil {
			debug.PrintStack()
			if err, ok := r.(error); ok {
				panicErr = err
			} else {
@@ -1294,7 +1011,6 @@ func Execute(args []string) error {
	server := &Server{
		modelPath: *mpath,
		status:    llm.ServerStatusLaunched,
		hardErrCh: make(chan error, 1),
	}

	server.cond = sync.NewCond(&server.mu)

@@ -1,39 +1,36 @@
package harmony
package server

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"maps"
	"slices"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/template"
)

type harmonyParserState int

func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
	if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
		// heuristic to check whether the template expects to be parsed via harmony:
		// search for harmony tags that are nearly always used
		if template.Contains("<|start|>") && template.Contains("<|end|>") {
			return true
		}
	}

	return false
}

const (
	harmonyParserState_LookingForMessageStart harmonyParserState = iota
	harmonyParserState_ParsingHeader
	harmonyParserState_ParsingContent
)

func shouldUseHarmony(model Model) bool {
	if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
		// heuristic to check whether the template expects to be parsed via harmony:
		// search for harmony tags that are nearly always used
		if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
			return true
		}
	}

	return false
}

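For reference, a harmony-style turn as the two tags checked above imply it; the example string is illustrative rather than taken from a real template:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// A message runs <|start|>ROLE[<|channel|>NAME]<|message|>BODY<|end|>.
	turn := "<|start|>assistant<|channel|>final<|message|>Hello there!<|end|>"

	// the detection heuristic keys off the same two tags
	fmt.Println(strings.Contains(turn, "<|start|>") && strings.Contains(turn, "<|end|>")) // true
}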
func (s harmonyParserState) String() string {
	switch s {
	// we're looking for the message start tag
@@ -91,18 +88,17 @@ func (s *HarmonyParser) AddImplicitStart() {
	s.acc.WriteString("<|start|>assistant")
}

// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
	if prefillContentOrThinking != nil {
		if *prefillContentOrThinking {
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
	if lastMessage != nil && lastMessage.Role == "assistant" {
		// handle prefilling conditions
		if lastMessage.Content != "" {
			s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
			return
		} else {
		} else if lastMessage.Thinking != "" {
			s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
			return
		}
	}

	s.AddImplicitStart()
}

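The two prefill branches above seed the accumulator differently depending on whether the assistant message being continued carries visible content or only thinking. A trimmed sketch of that decision (message stands in for api.Message; the header strings match the ones written above):

package main

import "fmt"

// message is a trimmed stand-in for api.Message.
type message struct{ role, content, thinking string }

// prefillHeader mirrors the branch structure of AddImplicitStartOrPrefill.
func prefillHeader(last *message) string {
	if last != nil && last.role == "assistant" {
		if last.content != "" {
			return "<|start|>assistant<|channel|>final<|message|>" // continue visible content
		} else if last.thinking != "" {
			return "<|start|>assistant<|channel|>analysis<|message|>" // continue thinking
		}
	}
	return "<|start|>assistant" // no prefill: plain implicit start
}

func main() {
	fmt.Println(prefillHeader(&message{role: "assistant", content: "Hi"}))
	fmt.Println(prefillHeader(&message{role: "assistant", thinking: "hmm"}))
	fmt.Println(prefillHeader(nil))
}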
@@ -279,21 +275,19 @@ const (
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
// This is a higher level interface that maps harmony concepts into ollama concepts
type HarmonyMessageHandler struct {
	state           harmonyMessageState
	HarmonyParser   *HarmonyParser
	FunctionNameMap *FunctionNameMap
	state         harmonyMessageState
	harmonyParser *HarmonyParser
}

// NewHarmonyMessageHandler creates a new message handler
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
	return &HarmonyMessageHandler{
		state: harmonyMessageState_Normal,
		HarmonyParser: &HarmonyParser{
		harmonyParser: &HarmonyParser{
			MessageStartTag: "<|start|>",
			MessageEndTag:   "<|end|>",
			HeaderEndTag:    "<|message|>",
		},
		FunctionNameMap: NewFunctionNameMap(),
	}
}

@@ -304,7 +298,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
	thinkingSb := strings.Builder{}
	toolContentSb := strings.Builder{}

	events := h.HarmonyParser.AddContent(content)
	events := h.harmonyParser.AddContent(content)
	for _, event := range events {
		switch event := event.(type) {
		case HarmonyEventHeaderComplete:
@@ -384,129 +378,3 @@ func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
func (a *HarmonyToolCallAccumulator) Content() string {
	return a.acc.String()
}

// FunctionNameMap maps a user-specified function name to a valid function
// name for harmony (which look like TypeScript identifiers). This is needed to
// transform user-specified function names, which might contain characters that
// are not allowed in TypeScript identifiers
type FunctionNameMap struct {
	userToHarmony map[string]string
	harmonyToUser map[string]string
}

func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
	// necessary to avoid exposing map internals
	type alias struct {
		UserToHarmony map[string]string `json:"userToHarmony"`
		HarmonyToUser map[string]string `json:"harmonyToUser"`
	}
	return json.Marshal(alias{
		UserToHarmony: m.userToHarmony,
		HarmonyToUser: m.harmonyToUser,
	})
}

func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
	type alias struct {
		UserToHarmony map[string]string `json:"userToHarmony"`
		HarmonyToUser map[string]string `json:"harmonyToUser"`
	}
	var a alias
	if err := json.Unmarshal(b, &a); err != nil {
		return err
	}
	if m.userToHarmony == nil {
		m.userToHarmony = make(map[string]string)
	}
	if m.harmonyToUser == nil {
		m.harmonyToUser = make(map[string]string)
	}
	maps.Copy(m.userToHarmony, a.UserToHarmony)
	maps.Copy(m.harmonyToUser, a.HarmonyToUser)
	return nil
}

func NewFunctionNameMap() *FunctionNameMap {
	return &FunctionNameMap{
		userToHarmony: make(map[string]string),
		harmonyToUser: make(map[string]string),
	}
}

func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
	harmonyFunctionName := m.deriveName(userFunctionName)
	m.userToHarmony[userFunctionName] = harmonyFunctionName
	m.harmonyToUser[harmonyFunctionName] = userFunctionName
	return harmonyFunctionName
}

// OriginalFromConverted looks up the reverse-mapping of a previously-converted
// user->harmony function name. To unmap reliably, the mapping must exist, as
// the conversion process is not reversible without the appropriate state
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
	if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
		return userFunctionName
	}
	slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
	// fallback to the original function name if we can't find a mapping
	return harmonyFunctionName
}

// convertToValidChars converts a user-specified function name to a valid
// TypeScript identifier.
//
// Limitations:
//
// - This doesn't restrict reserved TypeScript keywords.
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
//   restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
//   identifiers these models were trained on, so in the end we might want to
//   convert unicode-heavy identifiers to their closest ASCII equivalents.
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
	mapper := func(r rune) rune {
		// first, replace certain characters with underscores
		if r == ' ' || r == '-' || r == '.' {
			return '_'
		}

		if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
			return r
		}

		// finally, remove any other characters
		return -1
	}
	candidate := strings.Map(mapper, userFunctionName)

	// set a default name if we end up with nothing left
	if candidate == "" {
		return "unnamed"
	}

	// if the candidate starts with a number, prepend an underscore to make it a
	// valid identifier
	if unicode.IsDigit(rune(candidate[0])) {
		candidate = "_" + candidate
	}

	return candidate
}

func (m *FunctionNameMap) deriveName(userFunctionName string) string {
	originalCandidate := m.convertToValidChars(userFunctionName)
	candidate := originalCandidate

	// Check for dupes, and if so, add a number to the end.
	// We start at 2 because if we have dupes and the first is never renamed, it
	// makes sense for them to be named, say, `f`, `f_2`, `f_3`
	count := 2
	for {
		if _, exists := m.harmonyToUser[candidate]; !exists {
			break
		}
		candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
		count++
	}

	return candidate
}
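A usage sketch for the map above, written as a Go example function; it assumes it sits in the same package as FunctionNameMap with "fmt" imported, and the expected values follow the tests later in this diff:

func ExampleFunctionNameMap() {
	m := NewFunctionNameMap()
	fmt.Println(m.ConvertAndAdd("get weather")) // sanitized
	fmt.Println(m.ConvertAndAdd("get-weather")) // a dupe once sanitized, so numbered
	fmt.Println(m.OriginalFromConverted("get_weather_2"))
	// Output:
	// get_weather
	// get_weather_2
	// get-weather
}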
@@ -1,4 +1,4 @@
package harmony
package server

import (
	"fmt"
@@ -467,71 +467,3 @@ func TestHarmonyParserStreaming(t *testing.T) {
		})
	}
}

// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
// handle any saving (and therefore no dupe handling)
func TestFunctionConvertToValidChars(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want string
	}{
		{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
		{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
		{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
		{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
		{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
		{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
		{name: "leading number", in: "123", want: "_123"},
		{name: "$ allowed", in: "$", want: "$"},
		// show that we allow weird unicode letter characters, though we might want
		// to convert them to their closest ASCII equivalents in the future
		{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
		// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
		{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
	}

	for i, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := NewFunctionNameMap()
			got := parser.convertToValidChars(tt.in)
			if got != tt.want {
				t.Errorf("case %d: got %q, want %q", i, got, tt.want)
			}
		})
	}
}

func TestFunctionConvertAndAdd(t *testing.T) {
	// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
	tests := []struct {
		name string
		in   []string
		want []string
	}{
		{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
		{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
		{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
		{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
	}

	for i, tt := range tests {
		parser := NewFunctionNameMap()
		t.Run(tt.name, func(t *testing.T) {
			for j, in := range tt.in {
				got := parser.ConvertAndAdd(in)
				want := tt.want[j]
				if got != want {
					t.Errorf("case %d: got %q, want %q", i, got, want)
				}
				// check that the maps are correct
				if parser.userToHarmony[in] != want {
					t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
				}
				if parser.harmonyToUser[want] != in {
					t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
				}
			}
		})
	}
}
143
server/routes.go
@@ -32,7 +32,6 @@ import (
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/harmony"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/openai"
@@ -195,7 +194,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
	useHarmony := shouldUseHarmony(*m) && !req.Raw
	var harmonyMessageHandler *HarmonyMessageHandler
	var harmonyToolParser *HarmonyToolCallAccumulator
	if useHarmony {
		harmonyMessageHandler = NewHarmonyMessageHandler()
		harmonyMessageHandler.harmonyParser.AddImplicitStart()
		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
	}

	// Validate Think value: string values currently only allowed for gptoss models
	if req.Think != nil && req.Think.IsString() && !useHarmony {
@@ -308,19 +314,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		prompt = b.String()
	}

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
		c.JSON(http.StatusOK, api.DebugTemplateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			DebugInfo: api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},
		})
		return
	}

	var thinkingState *thinking.Parser
	if !useHarmony {
		openingTag, closingTag := thinking.InferTags(m.Template.Template)
@@ -356,7 +349,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
			},
		}

		if !useHarmony && thinkingState != nil {
		if useHarmony {
			content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
			res.Response = content
			res.Thinking = thinking
			harmonyToolParser.Add(toolContent)
		} else if thinkingState != nil {
			thinking, content := thinkingState.AddContent(cr.Content)
			res.Thinking = thinking
			res.Response = content
@@ -367,6 +365,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		}

		if cr.Done {
			if useHarmony {
				toolName, toolContent := harmonyToolParser.Drain()
				if toolName != nil {
					*toolName = strings.TrimPrefix(*toolName, "functions.")
					var args api.ToolCallFunctionArguments
					if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
						errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
						ch <- gin.H{"error": errStr}
						return
					}

					res.ToolCalls = append(res.ToolCalls, api.ToolCall{
						Function: api.ToolCallFunction{
							Name:      *toolName,
							Arguments: args,
						},
					})
				}
			}

			res.DoneReason = cr.DoneReason.String()
			res.TotalDuration = time.Since(checkpointStart)
			res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
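At Done time the handler drains the tool parser and parses the accumulated JSON arguments, as the hunk above shows. A standalone sketch of that step with an invented payload (map[string]any stands in for api.ToolCallFunctionArguments):

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

func main() {
	// What the drained values might look like; the payload is made up.
	toolName := "functions.get_weather"
	toolContent := `{"city": "Toronto"}`

	name := strings.TrimPrefix(toolName, "functions.")
	var args map[string]any
	if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
		panic(err) // the handler reports this on the response channel instead
	}
	fmt.Println(name, args["city"]) // get_weather Toronto
}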
@@ -1572,58 +1590,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
	}
	msgs = filterThinkTags(msgs, m)

	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)

	processedTools := req.Tools
	var functionNameMap *harmony.FunctionNameMap
	var prefillContentOrThinking *bool
	if useHarmony {
		functionNameMap = harmony.NewFunctionNameMap()
		var lastMessage *api.Message
		if len(msgs) > 0 {
			lastMessage = &msgs[len(msgs)-1]
		}

		// prefill content or thinking flag if the last message is an assistant message
		if lastMessage != nil && lastMessage.Role == "assistant" {
			if lastMessage.Content != "" {
				trueVal := true
				// true sets content to be prefilled
				prefillContentOrThinking = &trueVal
			} else if lastMessage.Thinking != "" {
				// false sets thinking to be prefilled
				falseVal := false
				prefillContentOrThinking = &falseVal
			}
		}
		// make a copy of tools to pass to the chat prompt. Function names may be
		// renamed to be valid Harmony function names.
		processedTools = make([]api.Tool, len(req.Tools))
		copy(processedTools, req.Tools)
		for i, tool := range processedTools {
			processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
		}
	}

	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
	prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
	if err != nil {
		slog.Error("chat prompt error", "error", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
		c.JSON(http.StatusOK, api.DebugTemplateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			DebugInfo: api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},
		})
		return
	}
	useHarmony := shouldUseHarmony(*m)

	// Validate Think value: string values currently only allowed for gptoss models
	if req.Think != nil && req.Think.IsString() && !useHarmony {
@@ -1631,6 +1605,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
		return
	}

	var harmonyMessageHandler *HarmonyMessageHandler
	var harmonyToolParser *HarmonyToolCallAccumulator

	if useHarmony {
		harmonyMessageHandler = NewHarmonyMessageHandler()
		var lastMessage *api.Message
		if len(msgs) > 0 {
			lastMessage = &msgs[len(msgs)-1]
		}
		harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
	}

	var thinkingState *thinking.Parser
	openingTag, closingTag := thinking.InferTags(m.Template.Template)
	if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
@@ -1650,17 +1637,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
		defer close(ch)

		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt:          prompt,
			Images:          images,
			Format:          req.Format,
			Options:         opts,
			FunctionNameMap: functionNameMap,
			PrefillContent:  prefillContentOrThinking,
			Prompt:  prompt,
			Images:  images,
			Format:  req.Format,
			Options: opts,
		}, func(r llm.CompletionResponse) {
			res := api.ChatResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Message:   api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
				Message:   api.Message{Role: "assistant", Content: r.Content},
				Done:      r.Done,
				Metrics: api.Metrics{
					PromptEvalCount: r.PromptEvalCount,
@@ -1676,10 +1661,30 @@ func (s *Server) ChatHandler(c *gin.Context) {
			}

			if useHarmony {
				content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
				res.Message.Content = content
				res.Message.Thinking = thinking
				harmonyToolParser.Add(toolContent)

				if r.Done {
					toolName, toolContent := harmonyToolParser.Drain()
					if toolName != nil {
						*toolName = strings.TrimPrefix(*toolName, "functions.")
						var args api.ToolCallFunctionArguments
						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
							ch <- gin.H{"error": errStr}
							return
						}
						res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
					}
				}

				// only send messages with meaningful content (empty messages confuse clients)
				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
					ch <- res
				}

				return
			}

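The "meaningful content" guard in the hunk above is worth calling out: a streamed chunk is forwarded to the client only if it carries content, thinking, tool calls, or the done flag. A trimmed sketch of the predicate (msg stands in for the message inside api.ChatResponse):

package main

import "fmt"

type msg struct {
	content   string
	thinking  string
	toolCalls int
}

func shouldSend(m msg, done bool) bool {
	return m.content != "" || m.thinking != "" || m.toolCalls > 0 || done
}

func main() {
	fmt.Println(shouldSend(msg{}, false))                // false: empty chunk is dropped
	fmt.Println(shouldSend(msg{thinking: "..."}, false)) // true
	fmt.Println(shouldSend(msg{}, true))                 // true: the done marker always goes out
}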
@@ -1,413 +0,0 @@
package server

import (
	"bytes"
	"encoding/json"
	"net/http"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/llm"
)

func TestGenerateDebugRenderOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         llm.DoneReasonStop,
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
				return false
			},
		},
	}

	go s.sched.Run(t.Context())

	// Create a test model
	stream := false
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []*ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:    "test-model",
		Files:    map[string]string{"file.gguf": digest},
		Template: "{{ .Prompt }}",
		Stream:   &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	tests := []struct {
		name            string
		request         api.GenerateRequest
		expectDebug     bool
		expectTemplate  string
		expectNumImages int
	}{
		{
			name: "debug render only enabled",
			request: api.GenerateRequest{
				Model:           "test-model",
				Prompt:          "Hello, world!",
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "Hello, world!",
		},
		{
			name: "debug render only disabled",
			request: api.GenerateRequest{
				Model:           "test-model",
				Prompt:          "Hello, world!",
				DebugRenderOnly: false,
			},
			expectDebug: false,
		},
		{
			name: "debug render only with system prompt",
			request: api.GenerateRequest{
				Model:           "test-model",
				Prompt:          "User question",
				System:          "You are a helpful assistant",
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "User question",
		},
		{
			name: "debug render only with template",
			request: api.GenerateRequest{
				Model:           "test-model",
				Prompt:          "Hello",
				Template:        "PROMPT: {{ .Prompt }}",
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "PROMPT: Hello",
		},
		{
			name: "debug render only with images",
			request: api.GenerateRequest{
				Model:           "test-model",
				Prompt:          "Describe this image",
				Images:          []api.ImageData{[]byte("fake-image-data")},
				DebugRenderOnly: true,
			},
			expectDebug:     true,
			expectTemplate:  "[img-0]\n\nDescribe this image",
			expectNumImages: 1,
		},
		{
			name: "debug render only with raw mode",
			request: api.GenerateRequest{
				Model:           "test-model",
				Prompt:          "Raw prompt text",
				Raw:             true,
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "Raw prompt text",
		},
	}

	for _, tt := range tests {
		// Test both with and without streaming
		streamValues := []bool{false, true}
		for _, stream := range streamValues {
			streamSuffix := ""
			if stream {
				streamSuffix = " (streaming)"
			}
			t.Run(tt.name+streamSuffix, func(t *testing.T) {
				req := tt.request
				req.Stream = &stream
				w := createRequest(t, s.GenerateHandler, req)

				if tt.expectDebug {
					if w.Code != http.StatusOK {
						t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
					}

					var response api.DebugTemplateResponse
					if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
						t.Fatalf("failed to unmarshal response: %v", err)
					}

					if response.Model != tt.request.Model {
						t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
					}

					if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
						t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
					}

					if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
						t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
					}
				} else {
					// When debug is disabled, it should attempt normal processing
					if w.Code != http.StatusOK {
						t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
					}
				}
			})
		}
	}
}

func TestChatDebugRenderOnly(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         llm.DoneReasonStop,
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:  make(chan *LlmRequest, 1),
			finishedReqCh: make(chan *LlmRequest, 1),
			expiredCh:     make(chan *runnerRef, 1),
			unloadedCh:    make(chan any, 1),
			loaded:        make(map[string]*runnerRef),
			newServerFn:   newMockServer(&mock),
			getGpuFn:      discover.GetGPUInfo,
			getCpuFn:      discover.GetCPUInfo,
			reschedDelay:  250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
				// add small delay to simulate loading
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
				return false
			},
		},
	}

	go s.sched.Run(t.Context())

	// Create a test model
	stream := false
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "llama",
		"llama.block_count":             uint32(1),
		"llama.context_length":          uint32(8192),
		"llama.embedding_length":        uint32(4096),
		"llama.attention.head_count":    uint32(32),
		"llama.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, []*ggml.Tensor{
		{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
		{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
	})

	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:    "test-model",
		Files:    map[string]string{"file.gguf": digest},
		Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
		Stream:   &stream,
	})

	if w.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", w.Code)
	}

	tests := []struct {
		name            string
		request         api.ChatRequest
		expectDebug     bool
		expectTemplate  string
		expectNumImages int
	}{
		{
			name: "chat debug render only enabled",
			request: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "system", Content: "You are a helpful assistant"},
					{Role: "user", Content: "Hello"},
				},
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "system: You are a helpful assistant\nuser: Hello\n",
		},
		{
			name: "chat debug render only disabled",
			request: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
				},
				DebugRenderOnly: false,
			},
			expectDebug: false,
		},
		{
			name: "chat debug with assistant message",
			request: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Hello"},
					{Role: "assistant", Content: "Hi there!"},
					{Role: "user", Content: "How are you?"},
				},
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "user: Hello\nassistant: Hi there!\nuser: How are you?\n",
		},
		{
			name: "chat debug with images",
			request: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{
						Role:    "user",
						Content: "What's in this image?",
						Images:  []api.ImageData{[]byte("fake-image-data")},
					},
				},
				DebugRenderOnly: true,
			},
			expectDebug:     true,
			expectTemplate:  "user: [img-0]What's in this image?\n",
			expectNumImages: 1,
		},
		{
			name: "chat debug with tools",
			request: api.ChatRequest{
				Model: "test-model",
				Messages: []api.Message{
					{Role: "user", Content: "Get the weather"},
				},
				Tools: api.Tools{
					{
						Type: "function",
						Function: api.ToolFunction{
							Name:        "get_weather",
							Description: "Get weather information",
						},
					},
				},
				DebugRenderOnly: true,
			},
			expectDebug:    true,
			expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\",\"required\":null,\"properties\":null}}}]user: Get the weather\n",
		},
	}

	for _, tt := range tests {
		// Test both with and without streaming
		streamValues := []bool{false, true}
		for _, stream := range streamValues {
			streamSuffix := ""
			if stream {
				streamSuffix = " (streaming)"
			}
			t.Run(tt.name+streamSuffix, func(t *testing.T) {
				req := tt.request
				req.Stream = &stream
				w := createRequest(t, s.ChatHandler, req)

				if tt.expectDebug {
					if w.Code != http.StatusOK {
						t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
					}

					var response api.DebugTemplateResponse
					if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
						t.Fatalf("failed to unmarshal response: %v", err)
					}

					if response.Model != tt.request.Model {
						t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
					}

					if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
						t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
					}

					if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
						t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
					}
				} else {
					// When debug is disabled, it should attempt normal processing
					if w.Code != http.StatusOK {
						t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
					}
				}
			})
		}
	}
}