Merge remote-tracking branch 'upstream/main' into vulkanV3
@@ -1,18 +1,19 @@
package ollamarunner

import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"hash/maphash"
"image"
"log"
"log/slog"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
@@ -21,10 +22,13 @@ import (
"time"
"unicode/utf8"

"golang.org/x/image/bmp"
"golang.org/x/sync/semaphore"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -34,14 +38,13 @@ import (
_ "github.com/ollama/ollama/model/models"
)

type contextList struct {
list []ml.Context
}

type Sequence struct {
// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
// multimodal embeddings
ctxs *contextList
ctxs []ml.Context

// mmStore holds multimodal embeddings to manage memory and enable splitting across batches
mmStore multimodalStore

// batch index
iBatch int
@@ -82,7 +85,7 @@ type Sequence struct {
// true if an embedding is to be returned instead of text generation
embeddingOnly bool

doneReason string
doneReason llm.DoneReason

// Metrics
startProcessingTime time.Time
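Editor's note: the hunks below replace the bare string `doneReason` with the typed `llm.DoneReason`, so termination reasons can no longer drift into ad-hoc strings like "limit" vs "length". A minimal, self-contained sketch of why a typed reason helps; the constant names mirror the ones used in this diff, but the enum definition itself is illustrative, not ollama's actual llm package:

package main

import "fmt"

// DoneReason is an illustrative stand-in for llm.DoneReason: an int-backed
// enum that still renders as the strings the HTTP API reports.
type DoneReason int

const (
    DoneReasonStop DoneReason = iota
    DoneReasonLength
    DoneReasonConnectionClosed
)

func (r DoneReason) String() string {
    switch r {
    case DoneReasonLength:
        return "length"
    case DoneReasonConnectionClosed:
        return "connection"
    default:
        return "stop"
    }
}

func main() {
    // A typed value is checked by the compiler; a typo like "lenght" would not be.
    fmt.Println(DoneReasonLength) // prints "length"
}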
@@ -104,7 +107,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
startTime := time.Now()

inputs, ctxs, err := s.inputs(prompt, images)
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -123,8 +126,36 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
if int32(len(inputs)) > s.cache.numCtx {
discard := int32(len(inputs)) - s.cache.numCtx
promptStart := params.numKeep + discard

// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
sameBatch := 0
for i, inp := range inputs {
if sameBatch > 0 {
sameBatch--

if promptStart == int32(i) {
promptStart++
}
} else if promptStart == int32(i) {
break
}

if inp.SameBatch != 0 {
if int32(i) < params.numKeep {
return nil, fmt.Errorf("SameBatch may not be specified within numKeep (index: %v numKeep: %v SameBatch: %v)", i, params.numKeep, inp.SameBatch)
}

sameBatch = inp.SameBatch
}
}

if promptStart >= int32(len(inputs)) {
return nil, errors.New("entire prompt removed by truncation")
}

newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
newInputs = append(newInputs, inputs[promptStart:]...)

slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
inputs = newInputs
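Editor's note: the truncation hunk above keeps the first numKeep inputs, drops `discard` inputs, and then slides the cut point forward so it never lands inside an unbreakable SameBatch group. A standalone sketch of that walk, using a simplified input type (field names chosen for illustration):

package main

import (
    "errors"
    "fmt"
)

// item is a simplified stand-in for input.Input: SameBatch > 0 means the next
// SameBatch items must stay in the same batch as this one.
type item struct {
    Token     int32
    SameBatch int
}

// truncate drops `discard` items after the first `numKeep`, moving the cut
// point past any SameBatch group it would otherwise split.
func truncate(inputs []item, numKeep, discard int) ([]item, error) {
    promptStart := numKeep + discard

    sameBatch := 0
    for i, inp := range inputs {
        if sameBatch > 0 {
            sameBatch--
            if promptStart == i {
                promptStart++
            }
        } else if promptStart == i {
            break
        }
        if inp.SameBatch != 0 {
            sameBatch = inp.SameBatch
        }
    }

    if promptStart >= len(inputs) {
        return nil, errors.New("entire prompt removed by truncation")
    }

    out := append([]item{}, inputs[:numKeep]...)
    return append(out, inputs[promptStart:]...), nil
}

func main() {
    in := []item{{Token: 1}, {Token: 2}, {Token: 3, SameBatch: 2}, {Token: 4}, {Token: 5}, {Token: 6}}
    out, _ := truncate(in, 1, 2)
    // The naive cut at index 3 would split the group starting at token 3,
    // so the cut moves past it and only tokens 1 and 6 remain.
    fmt.Println(out)
}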
@@ -134,6 +165,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
return &Sequence{
ctxs: ctxs,
mmStore: mmStore,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -152,8 +184,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
var inputs []input.Input
var ctxs []ml.Context
var mmStore multimodalStore

var parts []string
var matches [][]string
@@ -163,23 +198,17 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
mmStore = newMultimodalStore()
} else {
parts = []string{prompt}
}
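Editor's note: the prompt splitting above is plain regexp work: re.Split yields the text segments and FindAllStringSubmatch yields the image indices that sat between them. It can be tried in isolation:

package main

import (
    "fmt"
    "regexp"
)

func main() {
    prompt := "describe [img-0] and compare it with [img-1]"

    re := regexp.MustCompile(`\[img-(\d+)\]`)
    parts := re.Split(prompt, -1)
    matches := re.FindAllStringSubmatch(prompt, -1)

    // There is always one more text part than there are image markers.
    for i, part := range parts {
        fmt.Printf("text part %d: %q\n", i, part)
        if i < len(matches) {
            fmt.Printf("image ref:   %s\n", matches[i][1]) // "0", then "1"
        }
    }
}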
var contexts contextList
runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
for _, ctx := range ctxs {
ctx.Close()
}
}, contexts.list)

postTokenize := false
for i, part := range parts {
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

for _, t := range tokens {
@@ -199,20 +228,23 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
}

if imageIndex < 0 {
return nil, nil, fmt.Errorf("invalid image index: %d", n)
return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
}

ctx := s.model.Backend().NewContext()
contexts.list = append(contexts.list, ctx)
runtime.SetFinalizer(ctx, func(c ml.Context) { c.Close() })
ctxs = append(ctxs, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()

mmStore.addMultimodal(imageEmbeddings)

inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
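Editor's note: multimodalHash above is a hash/maphash hasher used to fingerprint image bytes so equality of non-text data can be compared cheaply. A minimal stdlib-only illustration of the Reset/Write/Sum64 pattern:

package main

import (
    "fmt"
    "hash/maphash"
)

func main() {
    // A zero maphash.Hash picks a random seed on first use; reusing the same
    // Hash value (as the server does with its field) keeps fingerprints comparable.
    var h maphash.Hash

    imageA := []byte("fake image bytes")
    imageB := []byte("fake image bytes")

    h.Reset()
    _, _ = h.Write(imageA)
    hashA := h.Sum64()

    h.Reset()
    _, _ = h.Write(imageB)
    hashB := h.Sum64()

    fmt.Println(hashA == hashB) // true: same bytes, same seed, same fingerprint
}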
@@ -222,11 +254,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
var err error
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
}

return inputs, &contexts, nil
return inputs, ctxs, mmStore, nil
}

type Server struct {
@@ -267,15 +299,12 @@ type Server struct {
// KV cache
cache *InputCache

// next sequence for prompt processing to avoid starvation
nextSeq int

// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash

// vocab is a llama.cpp vocab required for grammar-based
// constrained generation (json mode, structured outputs)
// TODO: this is temporary until Ollama sampling supports
// constrained generation
vocab *sample.Vocab
}

func (s *Server) allNil() bool {
@@ -313,7 +342,7 @@ func flushPending(seq *Sequence) bool {
}
}

func (s *Server) removeSequence(seqIndex int, reason string) {
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]

flushPending(seq)
@@ -348,16 +377,25 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()

var options input.Options
ctx := s.model.Backend().NewContext()
defer ctx.Close()

var batchInputs []int32
var batch input.Batch

resumeSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]

for i, seq := range s.seqs {
if seq == nil {
continue
}

// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, "limit")
s.removeSequence(seqIdx, llm.DoneReasonLength)
continue
}
@@ -368,19 +406,34 @@ func (s *Server) processBatch() error {
batchSize := s.batchSize

for j, inp := range seq.inputs {
for i, inp := range seq.inputs {
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
// will cause a break if we have existing inputs.
minBatch := 1 + inp.SameBatch
if minBatch > batchSize {
batchSize = minBatch
}

if len(seq.pendingInputs)+minBatch > batchSize {
// Stop if the required batch would put us over the total batch size (including tokens
// added by other sequences). If we haven't been able to add anything yet then pick up
// here again for the next batch to avoid starvation, though we can opportunistically
// check if other sequences can still squeeze something in.
if len(batchInputs)+minBatch > batchSize {
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
resumeSeq = seqIdx
}
break
}

// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}

// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
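Editor's note: the rewritten loop above now walks sequences round-robin starting at s.nextSeq and records resumeSeq when a sequence could not add anything, so the next batch starts there instead of always favoring the first slots. A stripped-down sketch of that visiting order (numbers stand in for real sequences):

package main

import "fmt"

// pickOrder mirrors the round-robin walk used for batch building: start at
// nextSeq, wrap around, and visit each slot exactly once.
func pickOrder(numSeqs, nextSeq int) []int {
    order := make([]int, 0, numSeqs)
    seqIdx := nextSeq - 1
    for i := 0; i < numSeqs; i++ {
        seqIdx = (seqIdx + 1) % numSeqs
        order = append(order, seqIdx)
    }
    return order
}

func main() {
    // With 4 sequences and nextSeq == 2, sequence 2 gets first crack at the
    // batch, which keeps sequence 0 from starving the others.
    fmt.Println(pickOrder(4, 2)) // [2 3 0 1]
}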
@@ -391,21 +444,33 @@ func (s *Server) processBatch() error {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
continue
} else {
return err
}
}
}

options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
if err != nil {
return err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
}

options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)

seq.iBatch = len(options.Outputs)
if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
seq.iBatch = len(batch.Outputs)
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
}
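Editor's note: ShiftCacheSlot can now hand back inputs that must be re-run instead of failing the whole batch; the handler above prepends them to the sequence's queue via errors.As. A self-contained sketch of that error-with-payload pattern (the type name mirrors the diff; the Error method and stub function are illustrative):

package main

import (
    "errors"
    "fmt"
)

// ErrReprocessInputs carries the inputs that were evicted from the cache and
// need to be fed through the model again.
type ErrReprocessInputs struct {
    Inputs []int32
}

func (e *ErrReprocessInputs) Error() string {
    return fmt.Sprintf("%d inputs need reprocessing", len(e.Inputs))
}

// shiftCache pretends the cache shift had to drop two cached inputs.
func shiftCache() error {
    return &ErrReprocessInputs{Inputs: []int32{7, 8}}
}

func main() {
    pending := []int32{9, 10}

    if err := shiftCache(); err != nil {
        var reprocess *ErrReprocessInputs
        if errors.As(err, &reprocess) {
            // Prepend the evicted inputs so they run before anything new.
            pending = append(reprocess.Inputs, pending...)
        } else {
            panic(err)
        }
    }

    fmt.Println(pending) // [7 8 9 10]
}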
@@ -413,14 +478,17 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}

if len(options.Inputs) == 0 {
if resumeSeq != -1 {
s.nextSeq = resumeSeq
} else {
s.nextSeq = seqIdx + 1
}

if len(batchInputs) == 0 {
return nil
}

ctx := s.model.Backend().NewContext()
defer ctx.Close()

modelOutput, err := model.Forward(ctx, s.model, options)
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -455,12 +523,12 @@ func (s *Server) processBatch() error {
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, "")
s.removeSequence(i, llm.DoneReasonStop)
continue
}

// sample a token
vocabSize := len(logits) / len(options.Outputs)
vocabSize := len(logits) / len(batch.Outputs)

token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
@@ -473,7 +541,7 @@ func (s *Server) processBatch() error {
// as it's important for the /api/generate context
// seq.responses <- piece

s.removeSequence(i, "stop")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
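Editor's note: each sequence records iBatch, its position among the requested outputs, and later carves its own logits out of the flat model output; the slicing in the sampling hunk above is just vocabSize-sized windows. A tiny worked example with made-up sizes:

package main

import "fmt"

func main() {
    // Flat logits for 2 requested outputs over a 4-token vocabulary.
    logits := []float32{
        0.1, 0.9, 0.3, 0.2, // output 0
        0.5, 0.4, 0.8, 0.7, // output 1
    }
    numOutputs := 2
    vocabSize := len(logits) / numOutputs

    iBatch := 1 // this sequence produced the second output in the batch
    mine := logits[iBatch*vocabSize : (iBatch+1)*vocabSize]
    fmt.Println(mine) // [0.5 0.4 0.8 0.7]
}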
@@ -509,7 +577,7 @@ func (s *Server) processBatch() error {
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]

s.removeSequence(i, "stop")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -522,7 +590,7 @@ func (s *Server) processBatch() error {
}

if !flushPending(seq) {
s.removeSequence(i, "connection")
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
@@ -551,14 +619,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}

var grammar *sample.Grammar
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
}
defer grammar.Free()
}
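Editor's note: the hunk above swaps the llama.cpp-vocab-backed sample.Grammar for a GrammarSampler built from the model's own TextProcessor. Conceptually a grammar sampler is a filter that masks tokens the grammar cannot accept before ordinary sampling runs; the sketch below illustrates only that idea and is not ollama's sample package API:

package main

import "fmt"

// allowed reports whether a candidate token keeps the output inside the
// "grammar"; here the grammar simply demands single digits, standing in for a
// real constrained decoder.
func allowed(token string) bool {
    return len(token) == 1 && token[0] >= '0' && token[0] <= '9'
}

// constrain zeroes the scores of disallowed tokens so the sampler can only
// pick grammar-conforming continuations.
func constrain(tokens []string, scores []float32) []float32 {
    out := make([]float32, len(scores))
    for i, tok := range tokens {
        if allowed(tok) {
            out[i] = scores[i]
        }
    }
    return out
}

func main() {
    tokens := []string{"7", "x", "3"}
    fmt.Println(constrain(tokens, []float32{0.5, 0.9, 0.2})) // [0.5 0 0.2]
}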
sampler := sample.NewSampler(
@@ -587,7 +656,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -599,6 +668,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -612,6 +682,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()

if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
@@ -633,14 +704,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
// Send the final response
doneReason := "stop"
if seq.doneReason == "limit" {
doneReason = "length"
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: doneReason,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
EvalCount: seq.numPredicted,
@@ -676,7 +742,104 @@ func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}

func (s *Server) loadModel(
func (s *Server) reserveWorstCaseGraph() error {
ctx := s.model.Backend().NewContext()
defer ctx.Close()

var err error
inputs := make([]input.Input, s.batchSize)
mmStore := newMultimodalStore()

// Multimodal strategy:
// - Encode a 2048x2048 image. This assumes that a single image of this
// size is sufficient to trigger the worst case. This is currently true
// because for existing models, only a single image fits in a batch.
// - Add the embedding to a full batch of tokens - this is necessary because
// the model may be looking for non-image data, such as <image> tags.
// - Run PostTokenize to execute any transformations between generated
// embeddings and what the forward pass expects.
// - The result may now be larger than a batch (images may not fit in a
// single batch), so trim based on what will fit and must be grouped together.
// - Fill out the rest of the space with text tokens.
if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
mmCtx := s.model.Backend().NewContext()
defer mmCtx.Close()

img := image.NewGray(image.Rect(0, 0, 2048, 2048))
var buf bytes.Buffer
bmp.Encode(&buf, img)

if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
mmStore.addMultimodal(inputs[0].Multimodal)

inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return err
}

for i, inp := range inputs {
minBatch := 1 + inp.SameBatch
if minBatch > s.batchSize {
inputs = inputs[i:min(i+minBatch, len(inputs))]
break
} else if i+minBatch > s.batchSize {
inputs = inputs[:i]
break
}
}

if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize)
copy(newInputs, inputs)
inputs = newInputs
}
}
}

var batch input.Batch

batchInputs := make([]int32, len(inputs))
batch.Positions = make([]int32, len(inputs))
batch.Sequences = make([]int, len(inputs))
for i, inp := range inputs {
batchInputs[i] = inp.Token
if inp.Multimodal != nil {
mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
if err != nil {
return err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
}

batch.Positions[i] = int32(i)
}

batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}

batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))

cache := s.model.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, true)
if err != nil {
return err
}
}

t, err := s.model.Forward(ctx, batch)
if err != nil {
return err
}

ctx.Forward(t).Reserve()

return nil
}
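Editor's note: reserveWorstCaseGraph synthesizes a 2048x2048 grayscale image and BMP-encodes it purely to exercise the multimodal path at its assumed worst case. That part is plain stdlib plus golang.org/x/image/bmp and can be tried on its own:

package main

import (
    "bytes"
    "fmt"
    "image"

    "golang.org/x/image/bmp"
)

func main() {
    // A blank grayscale canvas is enough: the goal is worst-case tensor
    // shapes for graph reservation, not meaningful pixels.
    img := image.NewGray(image.Rect(0, 0, 2048, 2048))

    var buf bytes.Buffer
    if err := bmp.Encode(&buf, img); err != nil {
        panic(err)
    }

    // Roughly 4 MB of BMP bytes that could be handed to EncodeMultimodal.
    fmt.Printf("encoded %d bytes\n", buf.Len())
}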
func (s *Server) initModel(
mpath string,
params ml.BackendParams,
lpath multiLPath,
@@ -684,23 +847,21 @@ func (s *Server) loadModel(
kvCacheType string,
kvSize int,
multiUserCache bool,
) {
) error {
var err error
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
return err
}

s.vocab = sample.NewVocab(mpath)

// TODO(jessegross): LoRA loading
if lpath.String() != "" {
panic("loras are not yet implemented")
return errors.New("loras are not yet implemented")
}

s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
if err != nil {
panic(err)
return err
}

if !s.cache.enabled && parallel > 1 {
@@ -712,6 +873,43 @@ func (s *Server) loadModel(
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

return s.reserveWorstCaseGraph()
}

func (s *Server) load(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
parallel int,
kvCacheType string,
kvSize int,
multiUserCache bool,
) {
err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
if err != nil {
var noMem ml.ErrNoMem
if errors.As(err, &noMem) {
// We can't yet handle this but in the future we will
s.cache.Close()
if s.model != nil {
s.model.Backend().Close()
}
}

panic(err)
}

slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())

err = s.model.Backend().Load(ctx,
func(progress float32) {
s.progress = progress
})
if err != nil {
panic(err)
}

s.status = llm.ServerStatusReady
s.ready.Done()
}
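Editor's note: initModel sizes s.seqsSem with the parallel sequence count; that weighted semaphore is what the completion handler blocks on before claiming a slot and releases when a sequence is removed. Basic golang.org/x/sync/semaphore usage, for reference:

package main

import (
    "context"
    "fmt"

    "golang.org/x/sync/semaphore"
)

func main() {
    parallel := 2
    seqsSem := semaphore.NewWeighted(int64(parallel))
    ctx := context.Background()

    // Each request acquires one slot before it may occupy a sequence...
    for i := 0; i < parallel; i++ {
        if err := seqsSem.Acquire(ctx, 1); err != nil {
            panic(err)
        }
        fmt.Println("sequence slot acquired:", i)
    }

    // ...and releases it when the sequence is removed, letting a waiter in.
    seqsSem.Release(1)
    fmt.Println("one slot free again:", seqsSem.TryAcquire(1))
}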
@@ -728,9 +926,8 @@ func Execute(args []string) error {
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
_ = fs.Bool("verbose", false, "verbose output (default: disabled)")
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
@@ -744,22 +941,7 @@ func Execute(args []string) error {
if err := fs.Parse(args); err != nil {
return err
}
level := slog.LevelInfo
if *verbose {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("starting ollama engine")
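Editor's note: the hand-rolled slog setup removed above (verbose flag, AddSource, ReplaceAttr trimming file paths) is replaced by logutil.NewLogger with the level taken from envconfig. For reference, the removed handler boiled down to this stdlib-only pattern:

package main

import (
    "log/slog"
    "os"
    "path/filepath"
)

func main() {
    handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
        Level:     slog.LevelDebug,
        AddSource: true,
        // Shorten source locations to just the file name, as the old runner did.
        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
            if attr.Key == slog.SourceKey {
                source := attr.Value.Any().(*slog.Source)
                source.File = filepath.Base(source.File)
            }
            return attr
        },
    })
    slog.SetDefault(slog.New(handler))

    slog.Info("starting ollama engine")
}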
server := &Server{
@@ -767,9 +949,14 @@ func Execute(args []string) error {
status: llm.ServerStatusLoadingModel,
}

server.cond = sync.NewCond(&server.mu)
server.ready.Add(1)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// TODO(jessegross): Parameters that need to be implemented:
// no-mmap
// mlock

var tensorSplitFloats []float32
if *tensorSplit != "" {
@@ -789,14 +976,7 @@ func Execute(args []string) error {
FlashAttention: *flashAttention,
}

server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)

server.cond = sync.NewCond(&server.mu)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.run(ctx)

addr := "127.0.0.1:" + strconv.Itoa(*port)