From 3892c3a7032c99db250c3266276c4525d243950a Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 14 Mar 2025 15:21:53 -0700 Subject: [PATCH 1/8] llm: remove internal subprocess req and resp types (#9324) This commit refactors the LLM subsystem by removing internal subprocess request and response types. It consolidates duplicate type definitions across the codebase, moving them to centralized locations. The change also standardizes interfaces between components, simplifies the ServerStatusResp struct, and moves the ParseDurationMs function to a common package. This cleanup reduces code duplication between different runner implementations (llamarunner and ollamarunner). --- llm/server.go | 136 +++++++------------------ runner/llamarunner/runner.go | 185 +++++++++------------------------- runner/ollamarunner/cache.go | 1 + runner/ollamarunner/runner.go | 157 +++++++---------------------- 4 files changed, 125 insertions(+), 354 deletions(-) diff --git a/llm/server.go b/llm/server.go index c6f117125..adc11aaea 100644 --- a/llm/server.go +++ b/llm/server.go @@ -402,7 +402,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) } - slog.Info("starting llama server", "cmd", s.cmd.String()) + slog.Info("starting llama server", "cmd", s.cmd) if envconfig.Debug() { filteredEnv := []string{} for _, ev := range s.cmd.Env { @@ -470,7 +470,7 @@ const ( // iota is reset to 0 ServerStatusError ) -func (s ServerStatus) ToString() string { +func (s ServerStatus) String() string { switch s { case ServerStatusReady: return "llm server ready" @@ -485,12 +485,9 @@ func (s ServerStatus) ToString() string { } } -type ServerStatusResp struct { - Status string `json:"status"` - SlotsIdle int `json:"slots_idle"` - SlotsProcessing int `json:"slots_processing"` - Error string `json:"error"` - Progress float32 `json:"progress"` +type ServerStatusResponse struct { + Status ServerStatus `json:"status"` + Progress float32 `json:"progress"` } func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { @@ -502,7 +499,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { } if s.cmd.ProcessState.ExitCode() == -1 { // Most likely a signal killed it, log some more details to try to help troubleshoot - slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState.String()) + slog.Warn("llama runner process no longer running", "sys", s.cmd.ProcessState.Sys(), "string", s.cmd.ProcessState) } return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg) } @@ -527,21 +524,19 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { return ServerStatusError, fmt.Errorf("read health request: %w", err) } - var status ServerStatusResp - if err := json.Unmarshal(body, &status); err != nil { + var ssr ServerStatusResponse + if err := json.Unmarshal(body, &ssr); err != nil { return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err) } - switch status.Status { - case "ok": - return ServerStatusReady, nil - case "no slot available": - return ServerStatusNoSlotsAvailable, nil - case "loading model": - s.loadProgress = status.Progress - return ServerStatusLoadingModel, nil + switch ssr.Status { + case ServerStatusLoadingModel: + s.loadProgress = ssr.Progress + return ssr.Status, nil + case ServerStatusReady, 
ServerStatusNoSlotsAvailable: + return ssr.Status, nil default: - return ServerStatusError, fmt.Errorf("server error: %+v", status) + return ssr.Status, fmt.Errorf("server error: %+v", ssr) } } @@ -616,7 +611,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { status, _ := s.getServerStatus(ctx) if lastStatus != status && status != ServerStatusReady { // Only log on status changes - slog.Info("waiting for server to become available", "status", status.ToString()) + slog.Info("waiting for server to become available", "status", status) } switch status { case ServerStatusReady: @@ -630,7 +625,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { slog.Debug(fmt.Sprintf("model load progress %0.2f", s.loadProgress)) stallTimer = time.Now().Add(stallDuration) } else if !fullyLoaded && int(s.loadProgress*100.0) >= 100 { - slog.Debug("model load completed, waiting for server to become available", "status", status.ToString()) + slog.Debug("model load completed, waiting for server to become available", "status", status) stallTimer = time.Now().Add(stallDuration) fullyLoaded = true } @@ -671,63 +666,26 @@ type ImageData struct { AspectRatioID int `json:"aspect_ratio_id"` } -type completion struct { - Content string `json:"content"` - Model string `json:"model"` - Prompt string `json:"prompt"` - Stop bool `json:"stop"` - StoppedLimit bool `json:"stopped_limit"` - - Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` - } -} - type CompletionRequest struct { Prompt string Format json.RawMessage Images []ImageData Options *api.Options + + Grammar string // set before sending the request to the subprocess } type CompletionResponse struct { - Content string - DoneReason string - Done bool - PromptEvalCount int - PromptEvalDuration time.Duration - EvalCount int - EvalDuration time.Duration + Content string `json:"content"` + DoneReason string `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { - request := map[string]any{ - "prompt": req.Prompt, - "stream": true, - "n_predict": req.Options.NumPredict, - "n_keep": req.Options.NumKeep, - "main_gpu": req.Options.MainGPU, - "temperature": req.Options.Temperature, - "top_k": req.Options.TopK, - "top_p": req.Options.TopP, - "min_p": req.Options.MinP, - "typical_p": req.Options.TypicalP, - "repeat_last_n": req.Options.RepeatLastN, - "repeat_penalty": req.Options.RepeatPenalty, - "presence_penalty": req.Options.PresencePenalty, - "frequency_penalty": req.Options.FrequencyPenalty, - "mirostat": req.Options.Mirostat, - "mirostat_tau": req.Options.MirostatTau, - "mirostat_eta": req.Options.MirostatEta, - "seed": req.Options.Seed, - "stop": req.Options.Stop, - "image_data": req.Images, - "cache_prompt": true, - } - if len(req.Format) > 0 { switch string(req.Format) { case `null`, `""`: @@ -735,7 +693,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu // these as "not set". 
break case `"json"`: - request["grammar"] = grammarJSON + req.Grammar = grammarJSON default: if req.Format[0] != '{' { return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format) @@ -746,10 +704,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if g == nil { return fmt.Errorf("invalid JSON schema in format") } - request["grammar"] = string(g) + req.Grammar = string(g) } } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + if err := s.sem.Acquire(ctx, 1); err != nil { if errors.Is(err, context.Canceled) { slog.Info("aborting completion request due to client closing the connection") @@ -770,7 +733,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu if err != nil { return err } else if status != ServerStatusReady { - return fmt.Errorf("unexpected server status: %s", status.ToString()) + return fmt.Errorf("unexpected server status: %s", status) } // Handling JSON marshaling with special characters unescaped. @@ -778,7 +741,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu enc := json.NewEncoder(buffer) enc.SetEscapeHTML(false) - if err := enc.Encode(request); err != nil { + if err := enc.Encode(req); err != nil { return fmt.Errorf("failed to marshal data: %v", err) } @@ -829,7 +792,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu evt = line } - var c completion + var c CompletionResponse if err := json.Unmarshal(evt, &c); err != nil { return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } @@ -853,20 +816,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu }) } - if c.Stop { - doneReason := "stop" - if c.StoppedLimit { - doneReason = "length" - } - - fn(CompletionResponse{ - Done: true, - DoneReason: doneReason, - PromptEvalCount: c.Timings.PromptN, - PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), - EvalCount: c.Timings.PredictedN, - EvalDuration: parseDurationMs(c.Timings.PredictedMS), - }) + if c.Done { + fn(c) return nil } } @@ -914,7 +865,7 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err if err != nil { return nil, err } else if status != ServerStatusReady { - return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) + return nil, fmt.Errorf("unexpected server status: %s", status) } data, err := json.Marshal(EmbeddingRequest{Content: input}) @@ -1059,12 +1010,3 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 { } return 0 } - -func parseDurationMs(ms float64) time.Duration { - dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms)) - if err != nil { - panic(err) - } - - return dur -} diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 8662afc1e..83802d604 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/runner/common" ) @@ -99,7 +100,7 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() @@ -163,7 +164,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params 
NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // generating image embeddings for each image -func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) { var inputs []input var parts []string var matches [][]string @@ -229,7 +230,7 @@ type Server struct { image *ImageContext // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -541,75 +542,18 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -620,26 +564,28 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - var samplingParams llama.SamplingParams - samplingParams.TopK = req.TopK - samplingParams.TopP = req.TopP - samplingParams.MinP = req.MinP - samplingParams.TypicalP = req.TypicalP - samplingParams.Temp = req.Temperature - samplingParams.RepeatLastN = req.RepeatLastN - samplingParams.PenaltyRepeat = req.RepeatPenalty - samplingParams.PenaltyFreq = req.FrequencyPenalty - 
samplingParams.PenaltyPresent = req.PresencePenalty - samplingParams.Mirostat = req.Mirostat - samplingParams.MirostatTau = req.MirostatTau - samplingParams.MirostatEta = req.MirostatEta - samplingParams.Seed = uint32(req.Seed) - samplingParams.Grammar = req.Grammar + // Extract options from the CompletionRequest + samplingParams := llama.SamplingParams{ + TopK: req.Options.TopK, + TopP: req.Options.TopP, + MinP: req.Options.MinP, + TypicalP: req.Options.TypicalP, + Temp: req.Options.Temperature, + RepeatLastN: req.Options.RepeatLastN, + PenaltyRepeat: req.Options.RepeatPenalty, + PenaltyFreq: req.Options.FrequencyPenalty, + PenaltyPresent: req.Options.PresencePenalty, + Mirostat: req.Options.Mirostat, + MirostatTau: req.Options.MirostatTau, + MirostatEta: req.Options.MirostatEta, + Seed: uint32(req.Options.Seed), + Grammar: req.Grammar, + } seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: req.NumKeep, + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: req.Options.NumKeep, samplingParams: &samplingParams, embedding: false, }) @@ -662,7 +608,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -691,7 +637,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -702,15 +648,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numDecoded, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numDecoded, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -721,17 +669,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { - var req EmbeddingRequest + var req llm.EmbeddingRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) return @@ -761,7 
+700,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -782,41 +721,17 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { embedding := <-seq.embedding - if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ + if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{ Embedding: embedding, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } } -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -879,7 +794,7 @@ func (s *Server) loadModel( panic(err) } - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -937,7 +852,7 @@ func Execute(args []string) error { parallel: *parallel, seqs: make([]*Sequence, *parallel), seqsSem: semaphore.NewWeighted(int64(*parallel)), - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } var tensorSplitFloats []float32 diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index a411fddb1..adcb3f738 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -107,6 +107,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp return nil, nil, err } + // TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved? 
if !cachePrompt { numPast = 0 } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index c380ef221..d6339a615 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -24,6 +24,7 @@ import ( "golang.org/x/sync/semaphore" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -94,7 +95,7 @@ type NewSequenceParams struct { embedding bool } -func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { +func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) { s.ready.Wait() startTime := time.Now() @@ -145,7 +146,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { +func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) { var inputs []input.Input var parts []string var matches [][]string @@ -222,7 +223,7 @@ type Server struct { model model.Model // status for external health reporting - loading, ready to serve, etc. - status ServerStatus + status llm.ServerStatus // current progress on loading the model progress float32 @@ -501,75 +502,18 @@ func (s *Server) processBatch() error { return nil } -// TODO (jmorganca): use structs from the api package to avoid duplication -// this way the api acts as a proxy instead of using a different api for the -// runner -type Options struct { - api.Runner - - NumKeep int `json:"n_keep"` - Seed int `json:"seed"` - NumPredict int `json:"n_predict"` - TopK int `json:"top_k"` - TopP float32 `json:"top_p"` - MinP float32 `json:"min_p"` - TypicalP float32 `json:"typical_p"` - RepeatLastN int `json:"repeat_last_n"` - Temperature float32 `json:"temperature"` - RepeatPenalty float32 `json:"repeat_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - FrequencyPenalty float32 `json:"frequency_penalty"` - Mirostat int `json:"mirostat"` - MirostatTau float32 `json:"mirostat_tau"` - MirostatEta float32 `json:"mirostat_eta"` - Stop []string `json:"stop"` -} - -type ImageData struct { - Data []byte `json:"data"` - ID int `json:"id"` - AspectRatioID int `json:"aspect_ratio_id"` -} - -type CompletionRequest struct { - Prompt string `json:"prompt"` - Images []ImageData `json:"image_data"` - Grammar string `json:"grammar"` - CachePrompt bool `json:"cache_prompt"` - - Options -} - -type Timings struct { - PredictedN int `json:"predicted_n"` - PredictedMS float64 `json:"predicted_ms"` - PromptN int `json:"prompt_n"` - PromptMS float64 `json:"prompt_ms"` -} - -type CompletionResponse struct { - Content string `json:"content"` - Stop bool `json:"stop"` - - Model string `json:"model,omitempty"` - Prompt string `json:"prompt,omitempty"` - StoppedLimit bool `json:"stopped_limit,omitempty"` - PredictedN int `json:"predicted_n,omitempty"` - PredictedMS float64 `json:"predicted_ms,omitempty"` - PromptN int `json:"prompt_n,omitempty"` - PromptMS float64 `json:"prompt_ms,omitempty"` - - Timings Timings `json:"timings"` -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { - var req CompletionRequest - req.Options = Options(api.DefaultOptions()) + var req 
llm.CompletionRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, "Bad request", http.StatusBadRequest) return } + if req.Options == nil { + opts := api.DefaultOptions() + req.Options = &opts + } + // Set the headers to indicate streaming w.Header().Set("Content-Type", "application/json") w.Header().Set("Transfer-Encoding", "chunked") @@ -591,18 +535,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } sampler := sample.NewSampler( - req.Temperature, - req.TopK, - req.TopP, - req.MinP, - req.Seed, + req.Options.Temperature, + req.Options.TopK, + req.Options.TopP, + req.Options.MinP, + req.Options.Seed, grammar, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ - numPredict: req.NumPredict, - stop: req.Stop, - numKeep: int32(req.NumKeep), + numPredict: req.Options.NumPredict, + stop: req.Options.Stop, + numKeep: int32(req.Options.NumKeep), sampler: sampler, embedding: false, }) @@ -625,7 +569,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { found := false for i, sq := range s.seqs { if sq == nil { - seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true) if err != nil { s.mu.Unlock() http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) @@ -652,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&CompletionResponse{ + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -663,15 +607,17 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { // Send the final response - if err := json.NewEncoder(w).Encode(&CompletionResponse{ - Stop: true, - StoppedLimit: seq.doneReason == "limit", - Timings: Timings{ - PromptN: seq.numPromptInputs, - PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), - PredictedN: seq.numPredicted, - PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), - }, + doneReason := "stop" + if seq.doneReason == "limit" { + doneReason = "length" + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + Done: true, + DoneReason: doneReason, + PromptEvalCount: seq.numPromptInputs, + PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime), + EvalCount: seq.numPredicted, + EvalDuration: time.Since(seq.startGenerationTime), }); err != nil { http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) } @@ -682,43 +628,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } -type EmbeddingRequest struct { - Content string `json:"content"` - CachePrompt bool `json:"cache_prompt"` -} - -type EmbeddingResponse struct { - Embedding []float32 `json:"embedding"` -} - -type HealthResponse struct { - Status string `json:"status"` - Progress float32 `json:"progress"` -} - -type ServerStatus int - -const ( - ServerStatusReady ServerStatus = iota - ServerStatusLoadingModel - ServerStatusError -) - -func (s ServerStatus) ToString() string { - switch s { - case ServerStatusReady: - return "ok" - case ServerStatusLoadingModel: - return "loading model" - default: - return "server error" - } -} - func (s *Server) health(w 
http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(&HealthResponse{ - Status: s.status.ToString(), + if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{ + Status: s.status, Progress: s.progress, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) @@ -772,7 +685,7 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) - s.status = ServerStatusReady + s.status = llm.ServerStatusReady s.ready.Done() } @@ -824,7 +737,7 @@ func Execute(args []string) error { server := &Server{ batchSize: *batchSize, - status: ServerStatusLoadingModel, + status: llm.ServerStatusLoadingModel, } // TODO(jessegross): Parameters that need to be implemented: From 9679f40146ed7e148a64abb14c7c619b162c1875 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 12 Mar 2025 16:56:11 -0700 Subject: [PATCH 2/8] ml: Allow models to constrain inputs to a single batch Models may require that a set of inputs all be processed as part of the same batch. For example, if an image has multiple patches with fully connected attention between them, we should not split the batch in the middle of an image. Fixes #9697 --- integration/llm_image_test.go | 29 +++++++++++++++++ model/input/input.go | 6 ++++ model/models/gemma3/model.go | 30 +++++------------ model/models/gemma3/model_text.go | 53 ++++++------------------------- runner/ollamarunner/runner.go | 12 ++++++- 5 files changed, 64 insertions(+), 66 deletions(-) diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go index c7b56890e..fbbd9d5ce 100644 --- a/integration/llm_image_test.go +++ b/integration/llm_image_test.go @@ -66,6 +66,35 @@ func TestIntegrationMllama(t *testing.T) { DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) } +func TestIntegrationSplitBatch(t *testing.T) { + image, err := base64.StdEncoding.DecodeString(imageEncoding) + require.NoError(t, err) + req := api.GenerateRequest{ + Model: "gemma3:4b", + // Fill up a chunk of the batch so the image will partially spill over into the next one + System: "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed aliquet, justo in malesuada lobortis, odio ligula volutpat quam, quis faucibus ipsum magna quis sapien. Aliquam in venenatis diam, eu viverra magna. Phasellus imperdiet hendrerit volutpat. Vivamus sem ex, facilisis placerat felis non, dictum elementum est. Phasellus aliquam imperdiet lacus, eget placerat ligula sodales vel. Pellentesque nec auctor mi. Curabitur arcu nisi, faucibus eget nunc id, viverra interdum mi. Curabitur ornare ipsum ex, ac euismod ex aliquam in. Vestibulum id magna at purus accumsan fermentum. Proin scelerisque posuere nunc quis interdum. Maecenas sed mollis nisl. Etiam vitae ipsum interdum, placerat est quis, tincidunt velit. Nullam tempor nibh non lorem volutpat efficitur. Cras laoreet diam imperdiet ipsum auctor bibendum. Suspendisse ultrices urna sed metus sagittis suscipit. Quisque ullamcorper aliquam nibh ut mollis. Aenean dapibus mauris pharetra, venenatis elit ac, hendrerit odio. Cras vestibulum erat tempor, lobortis justo eu, lobortis ipsum. Nam laoreet dapibus sem. Proin vel diam ultrices, elementum ante et, ornare lectus. Proin eu accumsan nisl. Praesent ac ex vitae ipsum vulputate tristique facilisis sit amet lacus. Nullam faucibus magna a pellentesque pretium. Nunc lacinia ullamcorper sollicitudin. 
Donec vitae accumsan turpis, sed porttitor est. Donec porttitor mi vitae augue faucibus, vel mollis diam tincidunt.", + Prompt: "what does the text in this image say?", + Stream: &stream, + Options: map[string]interface{}{ + "seed": 42, + "temperature": 0.0, + }, + Images: []api.ImageData{ + image, + }, + } + + // Note: sometimes it returns "the ollamas" sometimes "the ollams" + resp := "the ollam" + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + require.NoError(t, PullIfMissing(ctx, client, req.Model)) + // llava models on CPU can be quite slow to start, + DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second) +} + const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb AAUAAAABAAAAUgEoAAMAAAABAAIAAIdpAAQAAAABAAAAWgAAAAAAAABIAAAAAQAAAEgAAAABAAOgAQADAAAAAQABAACgAgAEAAAAAQAAANKgAwAEAAAAAQAA AHgAAAAAXdsepgAAAAlwSFlzAAALEwAACxMBAJqcGAAAAVlpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6 diff --git a/model/input/input.go b/model/input/input.go index 0cb3f3f41..30bdcf065 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -15,6 +15,12 @@ type Input struct { // stored in Multimodal, used for caching and comparing // equality. MultimodalHash uint64 + + // SameBatch forces the following number of tokens to be processed + // in a single batch, breaking and extending batches as needed. + // Useful for things like images that must be processed in one + // shot. + SameBatch int } // MultimodalIndex is a multimodal element (such as an image) diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 24193f15f..ccc7567c5 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -2,10 +2,9 @@ package gemma3 import ( "bytes" - "encoding/binary" - "hash/fnv" "image" "math" + "slices" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" @@ -112,36 +111,23 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return visionOutputs, nil } -type imageToken struct { - embedding ml.Tensor - index int -} - func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { var result []input.Input - fnvHash := fnv.New64a() for _, inp := range inputs { if inp.Multimodal == nil { result = append(result, inp) } else { - imageInputs := []input.Input{ - {Token: 108}, // "\n\n" - {Token: 255999}, // """ - } - result = append(result, imageInputs...) - - // add image embeddings inputMultimodal := inp.Multimodal.(ml.Tensor) - for i := range inputMultimodal.Dim(1) { - fnvHash.Reset() - binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash) - fnvHash.Write([]byte{byte(i)}) + result = append(result, + input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + input.Input{Token: 255999}, // """ + input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + ) - imageToken := imageToken{embedding: inputMultimodal, index: i} - result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()}) - } + // add image token placeholders + result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) 
result = append(result, input.Input{Token: 256000}, // diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 7a88c0921..567f65a5e 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -171,53 +171,20 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int { - var embedding ml.Tensor - var src, dst, length int - var except []int - - for _, image := range multimodal { - imageToken := image.Multimodal.(imageToken) - imageSrc := imageToken.index - imageDst := image.Index - - if embedding == nil { - embedding = imageToken.embedding - src = imageSrc - dst = imageDst - length = 1 - } else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst { - src = imageSrc - dst = imageDst - length++ - } else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst { - length++ - } else { - visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0)) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0)))) - - embedding = imageToken.embedding - src = imageSrc - dst = imageDst - length = 1 - } - - except = append(except, imageDst) - } - - if embedding != nil { - visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0)) - ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0)))) - } - - return except -} - func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) - except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal) + // set image embeddings + var except []int + for _, image := range opts.Multimodal { + visionOutputs := image.Multimodal.(ml.Tensor) + ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + + for i := range visionOutputs.Dim(1) { + except = append(except, image.Index+i) + } + } for i, layer := range m.Layers { // gemma alternates between the sliding window (local) and causal (global) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d6339a615..916ad45da 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -352,6 +352,8 @@ func (s *Server) processBatch() error { seq.cache.Inputs = []input.Input{} } + batchSize := s.batchSize + for j, inp := range seq.inputs { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if len(seq.pendingInputs) == 0 { @@ -364,7 +366,15 @@ func (s *Server) processBatch() error { } } - if j >= s.batchSize { + // If we are required to put following inputs into a single batch then extend the + // batch size. Since we are only extending the size the minimum amount possible, this + // will cause a break if we have pending inputs. 
+ minBatch := 1 + inp.SameBatch + if minBatch > batchSize { + batchSize = minBatch + } + + if len(seq.pendingInputs)+minBatch > batchSize { break } From 282bfaaa957fff8402eb7a4b1657f57c49939604 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 13 Mar 2025 20:32:50 -0700 Subject: [PATCH 3/8] ollamarunner: Use a separate context per multimodal input Currently there is a single context per sequence, shared all by all multimodal inputs. Since we build a vision encoder graph per image, with a large number of inputs we can eventually hit the maximum number of graph nodes per context. This changes to use a separate context for each image, ensuring that available resource limits are consistent. --- model/model.go | 2 +- model/models/gemma3/model.go | 2 +- model/models/mllama/model.go | 11 +++++++---- runner/ollamarunner/runner.go | 37 +++++++++++++++++++++++------------ 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/model/model.go b/model/model.go index fadea3246..53e47add9 100644 --- a/model/model.go +++ b/model/model.go @@ -60,7 +60,7 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. - PostTokenize(ml.Context, []input.Input) ([]input.Input, error) + PostTokenize([]input.Input) ([]input.Input, error) } // Base implements the common fields and methods for all models diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index ccc7567c5..32ad80f43 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -111,7 +111,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return visionOutputs, nil } -func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input for _, inp := range inputs { diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 071d77ac7..fa4d570ca 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return m.Projector.Forward(ctx, crossAttentionStates), nil } -func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var images []input.Input fnvHash := fnv.New64a() for i := range inputs { if inputs[i].Multimodal == nil { if len(images) > 0 { - inputs[i].Multimodal = images[0].Multimodal + inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)} inputs[i].MultimodalHash = images[0].MultimodalHash for j := 1; j < len(images); j++ { - inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3) + inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor)) fnvHash.Reset() binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash) binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash) @@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { var crossAttentionStates ml.Tensor if len(opts.Multimodal) > 0 { - crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor) + 
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) + if len(images) > 0 { + crossAttentionStates = images[len(images)-1] + } } inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 916ad45da..d4c24556c 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -34,10 +34,14 @@ import ( _ "github.com/ollama/ollama/model/models" ) +type contextList struct { + list []ml.Context +} + type Sequence struct { - // ctx for allocating tensors that last the lifetime of the sequence, such as + // ctxs are used for allocating tensors that last the lifetime of the sequence, such as // multimodal embeddings - ctx ml.Context + ctxs *contextList // batch index iBatch int @@ -99,9 +103,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe s.ready.Wait() startTime := time.Now() - ctx := s.model.Backend().NewContext() - inputs, err := s.inputs(ctx, prompt, images) + inputs, ctxs, err := s.inputs(prompt, images) if err != nil { return nil, fmt.Errorf("failed to process inputs: %w", err) } else if len(inputs) == 0 { @@ -127,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe // TODO(jessegross): Ingest cached history for grammar return &Sequence{ - ctx: ctx, + ctxs: ctxs, inputs: inputs, numPromptInputs: len(inputs), startProcessingTime: startTime, @@ -146,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) { +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) { var inputs []input.Input var parts []string var matches [][]string @@ -161,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ( parts = []string{prompt} } + var contexts contextList + runtime.AddCleanup(&contexts, func(ctxs []ml.Context) { + for _, ctx := range ctxs { + ctx.Close() + } + }, contexts.list) + postTokenize := false for i, part := range parts { // text - tokenize tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) if err != nil { - return nil, err + return nil, nil, err } for _, t := range tokens { @@ -186,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ( } if imageIndex < 0 { - return nil, fmt.Errorf("invalid image index: %d", n) + return nil, nil, fmt.Errorf("invalid image index: %d", n) } + ctx := s.model.Backend().NewContext() + contexts.list = append(contexts.list, ctx) imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data) if err != nil { - return nil, err + return nil, nil, err } s.multimodalHash.Reset() @@ -205,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ( if visionModel && postTokenize { var err error - inputs, err = multimodalProcessor.PostTokenize(ctx, inputs) + inputs, err = multimodalProcessor.PostTokenize(inputs) if err != nil { - return nil, err + return nil, nil, err } } - return inputs, nil + return inputs, &contexts, nil } type Server struct { @@ -306,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) { close(seq.responses) close(seq.embedding) seq.cache.InUse = false - 
seq.ctx.Close() s.seqs[seqIndex] = nil s.seqsSem.Release(1) } From 7bf793a6007ca11fae0180ea6f2ebd7258428bd4 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 12 Mar 2025 16:59:23 -0700 Subject: [PATCH 4/8] gemma3: Allow multiple image in a single input Previously processing multiple images in a batch would trigger segfaults so sending images together was disabled as a way to mitigate this. The trigger was processing one image on the CPU and one on the GPU. This can no longer happen: - The vision encoder is now on the GPU so both images would be processed on the GPU. - We require images to be fully contained in a batch and each image including its special tokens is over half the batch size. As a result, we will never get two images in the same batch. Fixes #9731 --- server/prompt.go | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/server/prompt.go b/server/prompt.go index d053f2a8d..5b5b958f1 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var system []api.Message isMllama := checkMllamaModelFamily(m) - isGemma3 := checkGemma3ModelFamily(m) var imageNumTokens int // TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent @@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. n := len(msgs) - 1 // in reverse, find all messages that fit into context window for i := n; i >= 0; i-- { - if (isMllama || isGemma3) && len(msgs[i].Images) > 1 { + if isMllama && len(msgs[i].Images) > 1 { return "", nil, errTooManyImages } @@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool { } return false } - -func checkGemma3ModelFamily(m *Model) bool { - for _, arch := range m.Config.ModelFamilies { - if arch == "gemma3" { - return true - } - } - return false -} From 2d2247e59e3995c00bf6bdf9bd2713bdf01f6921 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 14 Mar 2025 15:44:08 -0700 Subject: [PATCH 5/8] Align versions for local builds (#9635) Darwin was using a different pattern for the version string than linux or windows. 
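The computed VERSION only matters because GOFLAGS injects it into the binary at
link time via -X. Below is a minimal sketch of the receiving side, assuming a
version package shaped the way the -X=github.com/ollama/ollama/version.Version
flag in GOFLAGS implies (the file path and example values are illustrative, not
the actual source):

// version/version.go
package version

// Version holds the development default. Release builds are expected to
// override it at link time, for example:
//
//	go build -ldflags "-X=github.com/ollama/ollama/version.Version=1.2.3"
//
// Without that override, clients report the placeholder "0.0.0".
var Version = "0.0.0"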
--- scripts/build_darwin.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh index 76d0a6c2b..616e8501c 100755 --- a/scripts/build_darwin.sh +++ b/scripts/build_darwin.sh @@ -8,7 +8,7 @@ usage() { exit 1 } -export VERSION=${VERSION:-$(git describe --tags --dirty)} +export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")} export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" export CGO_CPPFLAGS='-mmacosx-version-min=11.3' From ef378ad673a3f01382add316835957b1d4184177 Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Fri, 14 Mar 2025 17:41:07 -0700 Subject: [PATCH 6/8] gemma3 quantization (#9776) --- llama/llama.cpp/src/llama-arch.cpp | 19 ++++ llama/llama.cpp/src/llama-arch.h | 1 + llama/llama.cpp/src/llama-model.cpp | 7 ++ llama/llama.cpp/src/llama-quant.cpp | 9 ++ llama/patches/0021-gemma3-quantization.patch | 113 +++++++++++++++++++ 5 files changed, 149 insertions(+) create mode 100644 llama/patches/0021-gemma3-quantization.patch diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b6f20286b..b443fcd3f 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MINICPM3, "minicpm3" }, { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index ec7422244..aad92a5d2 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -41,6 +41,7 @@ enum llm_arch { LLM_ARCH_MINICPM3, LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, + LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index ab1a07d10..701830418 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA3: + { + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA3: + { + } break; case LLM_ARCH_STARCODER2: { tok_embd = 
create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: + case LLM_ARCH_GEMMA3: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index 6eb1da08e..d2f3a5108 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // This used to be a regex, but has an extreme cost to compile times. bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + // don't quantize vision stuff + quantize &= name.find("v.blk.") == std::string::npos; + + quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; + quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; + quantize &= name.find("v.patch_embedding.weight") == std::string::npos; + quantize &= name.find("v.position_embedding.weight") == std::string::npos; + quantize &= name.find("v.post_layernorm.weight") == std::string::npos; + // quantize only 2D and 3D tensors (experts) quantize &= (ggml_n_dims(tensor) >= 2); diff --git a/llama/patches/0021-gemma3-quantization.patch b/llama/patches/0021-gemma3-quantization.patch new file mode 100644 index 000000000..4f6dbc11b --- /dev/null +++ b/llama/patches/0021-gemma3-quantization.patch @@ -0,0 +1,113 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Patrick Devine +Date: Fri, 14 Mar 2025 16:33:23 -0700 +Subject: [PATCH] gemma3 quantization + +--- + src/llama-arch.cpp | 19 +++++++++++++++++++ + src/llama-arch.h | 1 + + src/llama-model.cpp | 7 +++++++ + src/llama-quant.cpp | 9 +++++++++ + 4 files changed, 36 insertions(+) + +diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp +index b6f20286..b443fcd3 100644 +--- a/src/llama-arch.cpp ++++ b/src/llama-arch.cpp +@@ -37,6 +37,7 @@ static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, ++ { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, +@@ -804,6 +805,24 @@ static const std::map> LLM_TENSOR_N + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, ++ { ++ LLM_ARCH_GEMMA3, ++ { ++ { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, ++ { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, ++ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, ++ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, ++ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, ++ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, ++ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, ++ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, ++ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, ++ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, ++ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, ++ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, ++ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, ++ }, ++ }, + { + LLM_ARCH_STARCODER2, + { +diff --git a/src/llama-arch.h b/src/llama-arch.h +index ec742224..aad92a5d 100644 +--- a/src/llama-arch.h ++++ b/src/llama-arch.h +@@ -41,6 +41,7 @@ enum llm_arch { + LLM_ARCH_MINICPM3, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, ++ LLM_ARCH_GEMMA3, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, +diff --git a/src/llama-model.cpp 
b/src/llama-model.cpp +index ab1a07d1..70183041 100644 +--- a/src/llama-model.cpp ++++ b/src/llama-model.cpp +@@ -878,6 +878,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { + default: type = LLM_TYPE_UNKNOWN; + } + } break; ++ case LLM_ARCH_GEMMA3: ++ { ++ } break; + case LLM_ARCH_STARCODER2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +@@ -2537,6 +2540,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; ++ case LLM_ARCH_GEMMA3: ++ { ++ } break; + case LLM_ARCH_STARCODER2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); +@@ -4029,6 +4035,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { + case LLM_ARCH_PHIMOE: + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: ++ case LLM_ARCH_GEMMA3: + case LLM_ARCH_STARCODER2: + case LLM_ARCH_OPENELM: + case LLM_ARCH_GPTNEOX: +diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp +index 6eb1da08..d2f3a510 100644 +--- a/src/llama-quant.cpp ++++ b/src/llama-quant.cpp +@@ -737,6 +737,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + ++ // don't quantize vision stuff ++ quantize &= name.find("v.blk.") == std::string::npos; ++ ++ quantize &= name.find("mm.mm_input_projection.weight") == std::string::npos; ++ quantize &= name.find("mm.mm_soft_emb_norm.weight") == std::string::npos; ++ quantize &= name.find("v.patch_embedding.weight") == std::string::npos; ++ quantize &= name.find("v.position_embedding.weight") == std::string::npos; ++ quantize &= name.find("v.post_layernorm.weight") == std::string::npos; ++ + // quantize only 2D and 3D tensors (experts) + quantize &= (ggml_n_dims(tensor) >= 2); + From 8294676150dde3dd3d316644b86e42b07204ff9c Mon Sep 17 00:00:00 2001 From: Blake Mizerany Date: Fri, 14 Mar 2025 18:33:07 -0700 Subject: [PATCH 7/8] server/internal/client/ollama: set User-Agent for registry client (#9775) This sets the agent header in DefaultRegistry to include the version of the client, OS, and architecture in the previous format, with a minor twist. Note: The version is obtained from the build info, instead of the version in version.Version, which should not longer be necessary, but we can remove in a future commit. Using the build info is more accurate and also provides extra build information if the build is not tagged, and if it is "dirty". Previously, the version was just "0.0.0" with no other helpful information. The ollama.com registry and others handle this swimmingly. 
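For illustration, a minimal self-contained sketch of the header this produces,
mirroring the UserAgent helper added below (the request URL is a placeholder
and the exact version string depends on the build info embedded in the binary):

package main

import (
	"fmt"
	"net/http"
	"runtime"
	"runtime/debug"
)

// userAgent mirrors the UserAgent helper added in this patch: module version
// from the build info, plus architecture, OS, and Go version.
func userAgent() string {
	buildinfo, _ := debug.ReadBuildInfo()
	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
		buildinfo.Main.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())
}

func main() {
	// Placeholder URL; the registry client applies its UserAgent to outgoing requests.
	req, _ := http.NewRequest("GET", "https://registry.example.com/", nil)
	req.Header.Set("User-Agent", userAgent())
	fmt.Println(req.Header.Get("User-Agent")) // e.g. "ollama/(devel) (arm64 darwin) Go/go1.24.1"
}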
--- server/internal/client/ollama/registry.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index cf05f79ae..d1d01ba46 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -25,6 +25,7 @@ import ( "os" "path/filepath" "runtime" + "runtime/debug" "slices" "strconv" "strings" @@ -259,6 +260,7 @@ func DefaultRegistry() (*Registry, error) { } var rc Registry + rc.UserAgent = UserAgent() rc.Key, err = ssh.ParseRawPrivateKey(keyPEM) if err != nil { return nil, err @@ -274,6 +276,16 @@ func DefaultRegistry() (*Registry, error) { return &rc, nil } +func UserAgent() string { + buildinfo, _ := debug.ReadBuildInfo() + return fmt.Sprintf("ollama/%s (%s %s) Go/%s", + buildinfo.Main.Version, + runtime.GOARCH, + runtime.GOOS, + runtime.Version(), + ) +} + func (r *Registry) maxStreams() int { return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0)) } From 2c8b4846437747bd23e7a176f83011e39ec2128b Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Sat, 15 Mar 2025 12:09:02 -0700 Subject: [PATCH 8/8] fix: correctly save in interactive mode (#9788) This fixes the case where a FROM line in previous modelfile points to a file which may/may not be present in a different ollama instance. We shouldn't be relying on the filename though and instead just check if the FROM line was instead a valid model name and point to that instead. --- cmd/cmd_test.go | 129 +++++++++++++++++++++++++++++++++++++++++++++ cmd/interactive.go | 12 ++++- 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index f21a8f50b..41b03e1bd 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -757,3 +757,132 @@ func TestCreateHandler(t *testing.T) { }) } } + +func TestNewCreateRequest(t *testing.T) { + tests := []struct { + name string + from string + opts runOptions + expected *api.CreateRequest + }{ + { + "basic test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "", + Prompt: "You are a fun AI agent", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "parent model test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + }, + }, + { + "parent model as filepath test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "/some/file/like/etc/passwd", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "parent model as windows filepath test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "D:\\some\\file\\like\\etc\\passwd", + Messages: []api.Message{}, + WordWrap: true, + }, + &api.CreateRequest{ + From: "mymodel", + Model: "newmodel", + }, + }, + { + "options test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + Options: map[string]any{ + "temperature": 1.0, + }, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + Parameters: map[string]any{ + "temperature": 1.0, + }, + }, + }, + { + "messages test", + "newmodel", + runOptions{ + Model: "mymodel", + ParentModel: "parentmodel", + System: "You are a fun AI agent", + Messages: []api.Message{ + { + Role: "user", + Content: "hello there!", + }, + { + Role: "assistant", + Content: "hello to you!", + }, + }, + 
WordWrap: true, + }, + &api.CreateRequest{ + From: "parentmodel", + Model: "newmodel", + System: "You are a fun AI agent", + Messages: []api.Message{ + { + Role: "user", + Content: "hello there!", + }, + { + Role: "assistant", + Content: "hello to you!", + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := NewCreateRequest(tt.from, tt.opts) + if !cmp.Equal(actual, tt.expected) { + t.Errorf("expected output %#v, got %#v", tt.expected, actual) + } + }) + } +} diff --git a/cmd/interactive.go b/cmd/interactive.go index f3489b652..d85510d45 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -18,6 +18,7 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/types/errtypes" + "github.com/ollama/ollama/types/model" ) type MultilineState int @@ -459,9 +460,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } func NewCreateRequest(name string, opts runOptions) *api.CreateRequest { + parentModel := opts.ParentModel + + modelName := model.ParseName(parentModel) + if !modelName.IsValid() { + parentModel = "" + } + req := &api.CreateRequest{ - Name: name, - From: cmp.Or(opts.ParentModel, opts.Model), + Model: name, + From: cmp.Or(parentModel, opts.Model), } if opts.System != "" {