diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index f21a8f50b..41b03e1bd 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -757,3 +757,132 @@ func TestCreateHandler(t *testing.T) { }) } } + +func TestNewCreateRequest(t *testing.T) { + tests := []struct { + name string + from string + opts runOptions + expected *api.CreateRequest + }{ + { + "basic test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "", + Prompt: "You are a fun AI agent", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "parent model test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + }, + }, + { + "parent model as filepath test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "/some/file/like/etc/passwd", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "parent model as windows filepath test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "D:\\some\\file\\like\\etc\\passwd", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "options test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + Options: map[string]any{ + "temperature": 1.0, + }, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + Parameters: map[string]any{ + "temperature": 1.0, + }, + }, + }, + { + "messages test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + System: "You are a fun AI agent", + Messages: []api.Message{ + { + Role: "user", + Content: "hello there!", + }, + { + Role: "assistant", + Content: "hello to you!", + }, + }, + WordWrap: true, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + System: "You are a fun AI agent", + Messages: []api.Message{ + { + Role: "user", + Content: "hello there!", + }, + { + Role: "assistant", + Content: "hello to you!", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := NewCreateRequest(tt.from, tt.opts) + if !cmp.Equal(actual, tt.expected) { + t.Errorf("expected output %#v, got %#v", tt.expected, actual) + } + }) + } +} diff --git a/cmd/interactive.go b/cmd/interactive.go index f3489b652..d85510d45 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -18,6 +18,7 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/types/errtypes" + "github.com/ollama/ollama/types/model" ) type MultilineState int @@ -459,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } func NewCreateRequest(name string, opts runOptions) *api.CreateRequest { + parentModel := opts.ParentModel + + modelName := model.ParseName(parentModel) + if !modelName.IsValid() { + parentModel = "" + } + req := &api.CreateRequest{ - Name: name, - From: cmp.Or(opts.ParentModel, opts.Model), + Model: name, + From: cmp.Or(parentModel, opts.Model), } if opts.System != "" { diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index c7b56890e..fbbd9d5ce 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) { DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) } +func TestIntegrationSplitBatch(t *testing.T) { + image, err := base64.StdEncoding.DecodeString(imageEncoding) + require.NoError(t, err) + req := api.GenerateRequest{ + Model: "gemma3:4b", + // Fill up a chunk of the batch so the image will partially spill over into the next one + System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.", + Prompt: "what does the text in this image say?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + Images: []api.ImageData{ + image, + }, + } + + // Note: sometimes it returns "the ollamas" sometimes "the ollams" + resp := "the ollam" + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + require.NoError(t, PullIfMissing(ctx, client, req.Model)) + // llava models on CPU can be quite slow to start, + DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second) +} + const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6 diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b6f20286b..b443fcd3f 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index ec7422244..aad92a5d2 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -41,6 +41,7 @@ enum llm_arch { LLM_ARCH_MINICPM3, LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index ab1a07d10..701830418 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA3: + { + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA3: + { + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: + case LLM_ARCH_GEMMA3: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index 6eb1da08e..d2f3a5108 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // This used to be a regex, but has an extreme cost to compile times. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + // don't quantize vision stuff + quantize &= name.find("v.blk.") == std::string::npos; + + quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; + quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; + quantize &= name.find("v.patch_embedding.weight") == std::string::npos; + quantize &= name.find("v.position_embedding.weight") == std::string::npos; + quantize &= name.find("v.post_layernorm.weight") == std::string::npos; + // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0021-gemma3-quantization.patch b/llama/patches/0021-gemma3-quantization.patch new file mode 100644 index 000000000..4f6dbc11b --- /dev/null +++ b/llama/patches/0021-gemma3-quantization.patch @@ -0,0 +1,113 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Patrick Devine +Date: Fri, 14 Mar 2025 16:33:23 -0700 +Subject: [PATCH] gemma3 quantization + +--- + src/llama-arch.cpp | 19 +++++++++++++++++++ + src/llama-arch.h | 1 + + src/llama-model.cpp | 7 +++++++ + src/llama-quant.cpp | 9 +++++++++ + 4 files changed, 36 insertions(+) + +diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp +index b6f20286..b443fcd3 100644 +--- a/src/llama-arch.cpp ++++ b/src/llama-arch.cpp +@@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, ++ { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, +@@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, ++ { ++ LLM_ARCH_GEMMA3, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, ++ }, ++ }, + { + LLM_ARCH_STARCODER2, + { +diff --git a/src/llama-arch.h b/src/llama-arch.h +index ec742224..aad92a5d 100644 +--- a/src/llama-arch.h ++++ b/src/llama-arch.h +@@ -41,6 +41,7 @@ enum llm_arch { + LLM_ARCH_MINICPM3, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, ++ LLM_ARCH_GEMMA3, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, +diff --git a/src/llama-model.cpp b/src/llama-model.cpp +index ab1a07d1..70183041 100644 +--- a/src/llama-model.cpp ++++ b/src/llama-model.cpp +@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; ++ case LLM_ARCH_GEMMA3: ++ { ++ } break; + case LLM_ARCH_STARCODER2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; ++ case LLM_ARCH_GEMMA3: ++ { ++ } break; + case LLM_ARCH_STARCODER2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); +@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { + case LLM_ARCH_PHIMOE: + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: ++ case LLM_ARCH_GEMMA3: + case LLM_ARCH_STARCODER2: + case LLM_ARCH_OPENELM: + case LLM_ARCH_GPTNEOX: +diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp +index 6eb1da08..d2f3a510 100644 +--- a/src/llama-quant.cpp ++++ b/src/llama-quant.cpp +@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + ++ // don't quantize vision stuff ++ quantize &= name.find("v.blk.") == std::string::npos; ++ ++ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; ++ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; ++ quantize &= name.find("v.patch_embedding.weight") == std::string::npos; ++ quantize &= name.find("v.position_embedding.weight") == std::string::npos; ++ quantize &= name.find("v.post_layernorm.weight") == std::string::npos; ++ + // quantize only 2D and 3D tensors (experts) + quantize &= (ggml_n_dims(tensor) >= 2); + diff --git a/llm/server.go b/llm/server.go index 537cc1e1a..03baca828 100644 --- a/llm/server.go +++ b/llm/server.go @@ -404,7 +404,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) } - slog.Info("starting llama server", "cmd", s.cmd.String()) + slog.Info("starting llama server", "cmd", s.cmd) if envconfig.Debug() { filteredEnv := []string{} for _, ev := range s.cmd.Env { @@ -472,7 +472,7 @@ const ( // iota is reset to 0 ServerStatusError ) -func (s ServerStatus) ToString() string { +func (s ServerStatus) String() string { switch s { case ServerStatusReady: return "llm server ready" @@ -487,12 +487,9 @@ func (s ServerStatus) ToString() string { } } -type ServerStatusResp struct { - Status string `json:"status"` - SlotsIdle int `json:"slots_idle"` - SlotsProcessing int `json:"slots_processing"` - Error string `json:"error"` - Progress float32 `json:"progress"` +type ServerStatusResponse struct { + Status ServerStatus `json:"status"` + Progress float32 `json:"progress"` } func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { @@ -504,7 +501,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { } if s.cmd.ProcessState.ExitCode() == -1 { // Most likely a signal killed it, log some more details to try to help troubleshoot - slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String()) + slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState) } return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) } @@ -529,21 +526,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { return ServerStatusError, fmt.Errorf("read health request: %w", err) } - var status ServerStatusResp - if err := json.Unmarshal(body, &status); err != nil { + var ssr ServerStatusResponse + if err := json.Unmarshal(body, &ssr); err != nil { return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err) } - switch status.Status { - case "ok": - return ServerStatusReady, nil - case "no slot available": - return ServerStatusNoSlotsAvailable, nil - case "loading model": - s.loadProgress = status.Progress - return ServerStatusLoadingModel, nil + switch ssr.Status { + case ServerStatusLoadingModel: + s.loadProgress = ssr.Progress + return ssr.Status, nil + case ServerStatusReady, ServerStatusNoSlotsAvailable: + return ssr.Status, nil default: - return ServerStatusError, fmt.Errorf("server error: %+v", status) + return ssr.Status, fmt.Errorf("server error: %+v", ssr) } } @@ -618,7 +613,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { status, _ := s.getServerStatus(ctx) if lastStatus != status && status != ServerStatusReady { // Only log on status changes - slog.Info("waiting for server to become available", "status", status.ToString()) + slog.Info("waiting for server to become available", "status", status) } switch status { case ServerStatusReady: @@ -632,7 +627,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress)) stallTimer = time.Now().Add(stallDuration) } else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 { - slog.Debug("model load completed, waiting for server to become available", "status", status.ToString()) + slog.Debug("model load completed, waiting for server to become available", "status", status) stallTimer = time.Now().Add(stallDuration) fullyLoaded = true } @@ -673,63 +668,26 @@ type ImageData struct { AspectRatioID int `json:"aspect_ratio_id"` } -type completion struct { - Content string `json:"content"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Stop bool `json:"stop"` - StoppedLimit bool `json:"stopped_limit"` - - Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` - } -} - type CompletionRequest struct { Prompt string Format json.RawMessage Images []ImageData Options *api.Options + + Grammar string // set before sending the request to the subprocess } type CompletionResponse struct { - Content string - DoneReason string - Done bool - PromptEvalCount int - PromptEvalDuration time.Duration - EvalCount int - EvalDuration time.Duration + Content string `json:"content"` + DoneReason string `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { - request := map[string]any{ - "prompt": req.Prompt, - "stream": true, - "n_predict": req.Options.NumPredict, - "n_keep": req.Options.NumKeep, - "main_gpu": req.Options.MainGPU, - "temperature": req.Options.Temperature, - "top_k": req.Options.TopK, - "top_p": req.Options.TopP, - "min_p": req.Options.MinP, - "typical_p": req.Options.TypicalP, - "repeat_last_n": req.Options.RepeatLastN, - "repeat_penalty": req.Options.RepeatPenalty, - "presence_penalty": req.Options.PresencePenalty, - "frequency_penalty": req.Options.FrequencyPenalty, - "mirostat": req.Options.Mirostat, - "mirostat_tau": req.Options.MirostatTau, - "mirostat_eta": req.Options.MirostatEta, - "seed": req.Options.Seed, - "stop": req.Options.Stop, - "image_data": req.Images, - "cache_prompt": true, - } - if len(req.Format) > 0 { switch string(req.Format) { case `null`, `""`: @@ -737,7 +695,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu // these as "not set". break case `"json"`: - request["grammar"] = grammarJSON + req.Grammar = grammarJSON default: if req.Format[0] != '{' { return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) @@ -748,10 +706,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if g == nil { return fmt.Errorf("invalid JSON schema in format") } - request["grammar"] = string(g) + req.Grammar = string(g) } } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + if err := s.sem.Acquire(ctx, 1); err != nil { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") @@ -772,7 +735,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if err != nil { return err } else if status != ServerStatusReady { - return fmt.Errorf("unexpected server status: %s", status.ToString()) + return fmt.Errorf("unexpected server status: %s", status) } // Handling JSON marshaling with special characters unescaped. @@ -780,7 +743,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu enc := json.NewEncoder(buffer) enc.SetEscapeHTML(false) - if err := enc.Encode(request); err != nil { + if err := enc.Encode(req); err != nil { return fmt.Errorf("failed to marshal data: %v", err) } @@ -831,7 +794,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu evt = line } - var c completion + var c CompletionResponse if err := json.Unmarshal(evt, &c); err != nil { return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } @@ -855,20 +818,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu }) } - if c.Stop { - doneReason := "stop" - if c.StoppedLimit { - doneReason = "length" - } - - fn(CompletionResponse{ - Done: true, - DoneReason: doneReason, - PromptEvalCount: c.Timings.PromptN, - PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), - EvalCount: c.Timings.PredictedN, - EvalDuration: parseDurationMs(c.Timings.PredictedMS), - }) + if c.Done { + fn(c) return nil } } @@ -916,7 +867,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err if err != nil { return nil, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + return nil, fmt.Errorf("unexpected server status: %s", status) } data, err := json.Marshal(EmbeddingRequest{Content: input}) @@ -1061,12 +1012,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 { } return 0 } - -func parseDurationMs(ms float64) time.Duration { - dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) - if err != nil { - panic(err) - } - - return dur -} diff --git a/model/input/input.go b/model/input/input.go index 0cb3f3f41..30bdcf065 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -15,6 +15,12 @@ type Input struct { // stored in Multimodal, used for caching and comparing // equality. MultimodalHash uint64 + + // SameBatch forces the following number of tokens to be processed + // in a single batch, breaking and extending batches as needed. + // Useful for things like images that must be processed in one + // shot. + SameBatch int } // MultimodalIndex is a multimodal element (such as an image) diff --git a/model/model.go b/model/model.go index fadea3246..53e47add9 100644 --- a/model/model.go +++ b/model/model.go @@ -60,7 +60,7 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. - PostTokenize(ml.Context, []input.Input) ([]input.Input, error) + PostTokenize([]input.Input) ([]input.Input, error) } // Base implements the common fields and methods for all models diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 24193f15f..32ad80f43 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -2,10 +2,9 @@ package gemma3 import ( "bytes" - "encoding/binary" - "hash/fnv" "image" "math" + "slices" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" @@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return visionOutputs, nil } -type imageToken struct { - embedding ml.Tensor - index int -} - -func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input - fnvHash := fnv.New64a() for _, inp := range inputs { if inp.Multimodal == nil { result = append(result, inp) } else { - imageInputs := []input.Input{ - {Token: 108}, // "\n\n" - {Token: 255999}, // """ - } - result = append(result, imageInputs...) - - // add image embeddings inputMultimodal := inp.Multimodal.(ml.Tensor) - for i := range inputMultimodal.Dim(1) { - fnvHash.Reset() - binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash) - fnvHash.Write([]byte{byte(i)}) + result = append(result, + input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + input.Input{Token: 255999}, // """ + input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + ) - imageToken := imageToken{embedding: inputMultimodal, index: i} - result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()}) - } + // add image token placeholders + result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) result = append(result, input.Input{Token: 256000}, // diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 7a88c0921..567f65a5e 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int { - var embedding ml.Tensor - var src, dst, length int - var except []int - - for _, image := range multimodal { - imageToken := image.Multimodal.(imageToken) - imageSrc := imageToken.index - imageDst := image.Index - - if embedding == nil { - embedding = imageToken.embedding - src = imageSrc - dst = imageDst - length = 1 - } else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst { - src = imageSrc - dst = imageDst - length++ - } else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst { - length++ - } else { - visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0)) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0)))) - - embedding = imageToken.embedding - src = imageSrc - dst = imageDst - length = 1 - } - - except = append(except, imageDst) - } - - if embedding != nil { - visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0)) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0)))) - } - - return except -} - func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) - except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal) + // set image embeddings + var except []int + for _, image := range opts.Multimodal { + visionOutputs := image.Multimodal.(ml.Tensor) + ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + + for i := range visionOutputs.Dim(1) { + except = append(except, image.Index+i) + } + } for i, layer := range m.Layers { // gemma alternates between the sliding window (local) and causal (global) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 071d77ac7..fa4d570ca 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return m.Projector.Forward(ctx, crossAttentionStates), nil } -func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var images []input.Input fnvHash := fnv.New64a() for i := range inputs { if inputs[i].Multimodal == nil { if len(images) > 0 { - inputs[i].Multimodal = images[0].Multimodal + inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)} inputs[i].MultimodalHash = images[0].MultimodalHash for j := 1; j < len(images); j++ { - inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3) + inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor)) fnvHash.Reset() binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) @@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { var crossAttentionStates ml.Tensor if len(opts.Multimodal) > 0 { - crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor) + images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) + if len(images) > 0 { + crossAttentionStates = images[len(images)-1] + } } inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 8662afc1e..83802d604 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/runner/common" ) @@ -99,7 +100,7 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() @@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // generating image embeddings for each image -func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) { var inputs []input var parts []string var matches [][]string @@ -229,7 +230,7 @@ type Server struct { image *ImageContext // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var samplingParams llama.SamplingParams - samplingParams.TopK = req.TopK - samplingParams.TopP = req.TopP - samplingParams.MinP = req.MinP - samplingParams.TypicalP = req.TypicalP - samplingParams.Temp = req.Temperature - samplingParams.RepeatLastN = req.RepeatLastN - samplingParams.PenaltyRepeat = req.RepeatPenalty - samplingParams.PenaltyFreq = req.FrequencyPenalty - samplingParams.PenaltyPresent = req.PresencePenalty - samplingParams.Mirostat = req.Mirostat - samplingParams.MirostatTau = req.MirostatTau - samplingParams.MirostatEta = req.MirostatEta - samplingParams.Seed = uint32(req.Seed) - samplingParams.Grammar = req.Grammar + // Extract options from the CompletionRequest + samplingParams := llama.SamplingParams{ + TopK: req.Options.TopK, + TopP: req.Options.TopP, + MinP: req.Options.MinP, + TypicalP: req.Options.TypicalP, + Temp: req.Options.Temperature, + RepeatLastN: req.Options.RepeatLastN, + PenaltyRepeat: req.Options.RepeatPenalty, + PenaltyFreq: req.Options.FrequencyPenalty, + PenaltyPresent: req.Options.PresencePenalty, + Mirostat: req.Options.Mirostat, + MirostatTau: req.Options.MirostatTau, + MirostatEta: req.Options.MirostatEta, + Seed: uint32(req.Options.Seed), + Grammar: req.Grammar, + } seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: req.NumKeep, + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: req.Options.NumKeep, samplingParams: &samplingParams, embedding: false, }) @@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numDecoded, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numDecoded, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - var req EmbeddingRequest + var req llm.EmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) return @@ -761,7 +700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding - if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ + if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ Embedding: embedding, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } } -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -879,7 +794,7 @@ func (s *Server) loadModel( panic(err) } - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -937,7 +852,7 @@ func Execute(args []string) error { parallel: *parallel, seqs: make([]*Sequence, *parallel), seqsSem: semaphore.NewWeighted(int64(*parallel)), - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } var tensorSplitFloats []float32 diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index a411fddb1..adcb3f738 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } + // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? if !cachePrompt { numPast = 0 } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index c380ef221..d4c24556c 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -24,6 +24,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -33,10 +34,14 @@ import ( _ "github.com/ollama/ollama/model/models" ) +type contextList struct { + list []ml.Context +} + type Sequence struct { - // ctx for allocating tensors that last the lifetime of the sequence, such as + // ctxs are used for allocating tensors that last the lifetime of the sequence, such as // multimodal embeddings - ctx ml.Context + ctxs *contextList // batch index iBatch int @@ -94,13 +99,12 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() - ctx := s.model.Backend().NewContext() - inputs, err := s.inputs(ctx, prompt, images) + inputs, ctxs, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) } else if len(inputs) == 0 { @@ -126,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // TODO(jessegross): Ingest cached history for grammar return &Sequence{ - ctx: ctx, + ctxs: ctxs, inputs: inputs, numPromptInputs: len(inputs), startProcessingTime: startTime, @@ -145,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) { var inputs []input.Input var parts []string var matches [][]string @@ -160,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in parts = []string{prompt} } + var contexts contextList + runtime.AddCleanup(&contexts, func(ctxs []ml.Context) { + for _, ctx := range ctxs { + ctx.Close() + } + }, contexts.list) + postTokenize := false for i, part := range parts { // text - tokenize tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { - return nil, err + return nil, nil, err } for _, t := range tokens { @@ -185,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in } if imageIndex < 0 { - return nil, fmt.Errorf("invalid image index: %d", n) + return nil, nil, fmt.Errorf("invalid image index: %d", n) } + ctx := s.model.Backend().NewContext() + contexts.list = append(contexts.list, ctx) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) if err != nil { - return nil, err + return nil, nil, err } s.multimodalHash.Reset() @@ -204,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]in if visionModel && postTokenize { var err error - inputs, err = multimodalProcessor.PostTokenize(ctx, inputs) + inputs, err = multimodalProcessor.PostTokenize(inputs) if err != nil { - return nil, err + return nil, nil, err } } - return inputs, nil + return inputs, &contexts, nil } type Server struct { @@ -222,7 +235,7 @@ type Server struct { model model.Model // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -305,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) { close(seq.responses) close(seq.embedding) seq.cache.InUse = false - seq.ctx.Close() s.seqs[seqIndex] = nil s.seqsSem.Release(1) } @@ -351,6 +363,8 @@ func (s *Server) processBatch() error { seq.cache.Inputs = []input.Input{} } + batchSize := s.batchSize + for j, inp := range seq.inputs { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if len(seq.pendingInputs) == 0 { @@ -363,7 +377,15 @@ func (s *Server) processBatch() error { } } - if j >= s.batchSize { + // If we are required to put following inputs into a single batch then extend the + // batch size. Since we are only extending the size the minimum amount possible, this + // will cause a break if we have pending inputs. + minBatch := 1 + inp.SameBatch + if minBatch > batchSize { + batchSize = minBatch + } + + if len(seq.pendingInputs)+minBatch > batchSize { break } @@ -501,75 +523,18 @@ func (s *Server) processBatch() error { return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -591,18 +556,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } sampler := sample.NewSampler( - req.Temperature, - req.TopK, - req.TopP, - req.MinP, - req.Seed, + req.Options.Temperature, + req.Options.TopK, + req.Options.TopP, + req.Options.MinP, + req.Options.Seed, grammar, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: int32(req.NumKeep), + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: int32(req.Options.NumKeep), sampler: sampler, embedding: false, }) @@ -625,7 +590,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -652,7 +617,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -663,15 +628,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numPredicted, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numPredicted, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -682,43 +649,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -772,7 +706,7 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -824,7 +758,7 @@ func Execute(args []string) error { server := &Server{ batchSize: *batchSize, - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } // TODO(jessegross): Parameters that need to be implemented: diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 76d0a6c2b..616e8501c 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -8,7 +8,7 @@ usage() { exit 1 } -export VERSION=${VERSION:-$(git describe --tags --dirty)} +export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")} export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" export CGO_CPPFLAGS='-mmacosx-version-min=11.3' diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index cf05f79ae..d1d01ba46 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -25,6 +25,7 @@ import ( "os" "path/filepath" "runtime" + "runtime/debug" "slices" "strconv" "strings" @@ -259,6 +260,7 @@ func DefaultRegistry() (*Registry, error) { } var rc Registry + rc.UserAgent = UserAgent() rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) if err != nil { return nil, err @@ -274,6 +276,16 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } +func UserAgent() string { + buildinfo, _ := debug.ReadBuildInfo() + return fmt.Sprintf("ollama/%s (%s %s) Go/%s", + buildinfo.Main.Version, + runtime.GOARCH, + runtime.GOOS, + runtime.Version(), + ) +} + func (r *Registry) maxStreams() int { return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) } diff --git a/server/prompt.go b/server/prompt.go index d053f2a8d..5b5b958f1 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var system []api.Message isMllama := checkMllamaModelFamily(m) - isGemma3 := checkGemma3ModelFamily(m) var imageNumTokens int // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. n := len(msgs) - 1 // in reverse, find all messages that fit into context window for i := n; i >= 0; i-- { - if (isMllama || isGemma3) && len(msgs[i].Images) > 1 { + if isMllama && len(msgs[i].Images) > 1 { return "", nil, errTooManyImages } @@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool { } return false } - -func checkGemma3ModelFamily(m *Model) bool { - for _, arch := range m.Config.ModelFamilies { - if arch == "gemma3" { - return true - } - } - return false -}