Compare commits
8 Commits
v0.5.4
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac300ec32b | ||
|
|
444640f3c7 | ||
|
|
e515ac0595 | ||
|
|
a4b32736cf | ||
|
|
d0769313ed | ||
|
|
4537a89b26 | ||
|
|
85822544a9 | ||
|
|
08a832b482 |
77
cmd/cmd.go
77
cmd/cmd.go
@@ -8,6 +8,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
@@ -17,9 +18,11 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -30,11 +33,13 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
@@ -42,6 +47,7 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -516,6 +522,64 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
// unknownKey handles key validation when a connection fails due to an unknown key.
|
||||
// It attempts to open the browser for interactive sessions to let users connect their key,
|
||||
// falling back to command-line instructions for non-interactive sessions.
|
||||
// Returns nil if browser flow succeeds, or an error with connection instructions otherwise.
|
||||
func unknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if term.IsTerminal(int(os.Stdout.Fd())) && !envconfig.Noninteractive() {
|
||||
// URL encode the key and device name for the browser URL
|
||||
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(localPubKey))
|
||||
d, _ := os.Hostname()
|
||||
encodedDevice := url.QueryEscape(d)
|
||||
browserURL := fmt.Sprintf("https://ollama.com/connect?host=%s&key=%s", encodedDevice, encodedKey)
|
||||
if err := browser.OpenURL(browserURL); err == nil {
|
||||
fmt.Println("Opening browser to connect your device...")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -564,10 +628,19 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if p != nil {
|
||||
p.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
return unknownKey(err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
@@ -578,7 +651,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
destination := n.String()
|
||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||
if isOllamaHost {
|
||||
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
|
||||
}
|
||||
fmt.Printf("\nYou can find your model at:\n\n")
|
||||
@@ -1474,6 +1547,8 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_GPU_OVERHEAD"],
|
||||
envVars["OLLAMA_LOAD_TIMEOUT"],
|
||||
})
|
||||
case pushCmd:
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_NONINTERACTIVE"]})
|
||||
default:
|
||||
appendEnvDocs(cmd, envs)
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@@ -368,15 +369,13 @@ func TestGetModelfileName(t *testing.T) {
|
||||
|
||||
func TestPushHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
}{
|
||||
{
|
||||
name: "successful push",
|
||||
modelName: "test-model",
|
||||
modelName: "successful-push",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -389,8 +388,8 @@ func TestPushHandler(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
if req.Name != "successful-push" {
|
||||
t.Errorf("expected model name 'successful-push', got %s", req.Name)
|
||||
}
|
||||
|
||||
// Simulate progress updates
|
||||
@@ -409,11 +408,10 @@ func TestPushHandler(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
|
||||
},
|
||||
{
|
||||
name: "unauthorized push",
|
||||
modelName: "unauthorized-model",
|
||||
modelName: "unauthorized-push",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -428,10 +426,29 @@ func TestPushHandler(t *testing.T) {
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
{
|
||||
modelName: "unknown-key-err",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
uerr := errtypes.UnknownOllamaKey{
|
||||
Key: "aaa",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": uerr.Error(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "unauthorized: unknown ollama key \"aaa\"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Run(tt.modelName, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
||||
handler(w, r)
|
||||
|
||||
@@ -165,6 +165,9 @@ var (
|
||||
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
||||
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
||||
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
||||
// Noninteractive is true when CLI interactive features should be disabled.
|
||||
// This affects features like automatic browser opening.
|
||||
Noninteractive = Bool("OLLAMA_NONINTERACTIVE")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -250,6 +253,7 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_NONINTERACTIVE": {"OLLAMA_NONINTERACTIVE", Noninteractive(), "Disable interactive CLI features, such as automatically opening the browser"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
1
go.mod
1
go.mod
@@ -36,6 +36,7 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
|
||||
3
go.sum
3
go.sum
@@ -159,6 +159,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -281,6 +283,7 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
|
||||
99
llama/llama.cpp
vendored
99
llama/llama.cpp
vendored
@@ -3051,6 +3051,13 @@ struct llama_kv_cache {
|
||||
}
|
||||
};
|
||||
|
||||
// block of KV slots to move when defragging
|
||||
struct llama_kv_defrag_move {
|
||||
uint32_t src;
|
||||
uint32_t dst;
|
||||
uint32_t len;
|
||||
};
|
||||
|
||||
struct llama_control_vector {
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
@@ -10828,35 +10835,23 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
|
||||
struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == ids.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
for (const auto & move : moves) {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
n_embd_k_gqa, nm,
|
||||
n_embd_k_gqa, move.len,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
|
||||
|
||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
n_embd_k_gqa, nm,
|
||||
n_embd_k_gqa, move.len,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
|
||||
|
||||
ggml_tensor * view_v_src;
|
||||
ggml_tensor * view_v_dst;
|
||||
@@ -10864,31 +10859,29 @@ struct llm_build_context {
|
||||
if (flash_attn) {
|
||||
// NOTE: the V cache is not transposed when using flash attention
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
n_embd_v_gqa, nm,
|
||||
n_embd_v_gqa, move.len,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
n_embd_v_gqa, nm,
|
||||
n_embd_v_gqa, move.len,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
nm, n_embd_v_gqa,
|
||||
move.len, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
ggml_row_size(kv_self.v_l[il]->type, i));
|
||||
ggml_row_size(kv_self.v_l[il]->type, move.src));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
nm, n_embd_v_gqa,
|
||||
move.len, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
ggml_row_size(kv_self.v_l[il]->type, id));
|
||||
ggml_row_size(kv_self.v_l[il]->type, move.dst));
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
@@ -17351,7 +17344,7 @@ struct llm_build_context {
|
||||
}
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
llama_ubatch dummy = {};
|
||||
dummy.equal_seqs = true;
|
||||
|
||||
@@ -17361,7 +17354,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
|
||||
|
||||
llm.init();
|
||||
|
||||
struct ggml_cgraph * result = llm.build_defrag(ids);
|
||||
struct ggml_cgraph * result = llm.build_defrag(moves);
|
||||
|
||||
llm.free();
|
||||
|
||||
@@ -18377,7 +18370,12 @@ static int llama_decode_internal(
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
llama_kv_cache_defrag(kv_self);
|
||||
llama_kv_cache_update(&lctx);
|
||||
slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
}
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
@@ -18782,8 +18780,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// number of cells moved
|
||||
uint32_t n_moves = 0;
|
||||
// groups of cells moved
|
||||
std::vector<struct llama_kv_defrag_move> moves;
|
||||
|
||||
// each move requires 6*n_layer tensors (see build_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
@@ -18847,19 +18845,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
// should we stop searching for the next move?
|
||||
bool stop = false;
|
||||
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
if (n_moves == max_moves) {
|
||||
stop = true;
|
||||
break;
|
||||
}
|
||||
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
@@ -18875,8 +18865,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
kv_self.head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
n_moves++;
|
||||
moves.push_back({i1, i0 + nf, 1});
|
||||
cont = true;
|
||||
} else {
|
||||
moves.back().len++;
|
||||
}
|
||||
|
||||
nf++;
|
||||
@@ -18886,22 +18878,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
if (stop || n_moves == max_moves) {
|
||||
break;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
||||
if (n_moves == 0) {
|
||||
if (moves.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
|
||||
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
@@ -18976,11 +18962,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
#else
|
||||
// ggml_graph defrag
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
for (std::size_t i = 0; i < moves.size(); i += max_moves) {
|
||||
std::vector<struct llama_kv_defrag_move> chunk;
|
||||
auto end = std::min(i + max_moves, moves.size());
|
||||
chunk.assign(moves.begin() + i, moves.begin() + end);
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
|
||||
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
|
||||
ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
|
||||
|
||||
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
}
|
||||
#endif
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Fri, 13 Dec 2024 16:11:59 -0800
|
||||
Subject: [PATCH] llama: Ensure KV cache is fully defragmented.
|
||||
|
||||
Sometimes the KV cache requires defragmentation even without
|
||||
triggering the threshold heuristic. In this case, decoding
|
||||
will not being able to find a KV cache slot. This is particularly
|
||||
difficult for the caller to handle if it happens in between
|
||||
ubatches. To avoid this, we should immediately trigger a defrag.
|
||||
|
||||
In addition, a heavily fragmented cache can require more than
|
||||
max_moves to defragment. Currently, we stop when we hit the limit
|
||||
but this can leave a cache that still does not have adequate space
|
||||
even after defragmentation is triggered. Instead, we should do
|
||||
multiple batches of processing until everything is complete.
|
||||
---
|
||||
src/llama.cpp | 99 ++++++++++++++++++++++++---------------------------
|
||||
1 file changed, 46 insertions(+), 53 deletions(-)
|
||||
|
||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||
index 4778a9ed..654e32bc 100644
|
||||
--- a/src/llama.cpp
|
||||
+++ b/src/llama.cpp
|
||||
@@ -3025,6 +3025,13 @@ struct llama_kv_cache {
|
||||
}
|
||||
};
|
||||
|
||||
+// block of KV slots to move when defragging
|
||||
+struct llama_kv_defrag_move {
|
||||
+ uint32_t src;
|
||||
+ uint32_t dst;
|
||||
+ uint32_t len;
|
||||
+};
|
||||
+
|
||||
struct llama_control_vector {
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
@@ -10802,35 +10809,23 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
- struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
|
||||
+ struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
- for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
- const uint32_t id = ids[i];
|
||||
-
|
||||
- if (i == id || id == ids.size()) {
|
||||
- continue;
|
||||
- }
|
||||
-
|
||||
- uint32_t nm = 1;
|
||||
-
|
||||
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
- nm++;
|
||||
- }
|
||||
-
|
||||
+ for (const auto & move : moves) {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
- n_embd_k_gqa, nm,
|
||||
+ n_embd_k_gqa, move.len,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
- ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
|
||||
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
|
||||
|
||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
- n_embd_k_gqa, nm,
|
||||
+ n_embd_k_gqa, move.len,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
- ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
||||
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
|
||||
|
||||
ggml_tensor * view_v_src;
|
||||
ggml_tensor * view_v_dst;
|
||||
@@ -10838,31 +10833,29 @@ struct llm_build_context {
|
||||
if (flash_attn) {
|
||||
// NOTE: the V cache is not transposed when using flash attention
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- n_embd_v_gqa, nm,
|
||||
+ n_embd_v_gqa, move.len,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- n_embd_v_gqa, nm,
|
||||
+ n_embd_v_gqa, move.len,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- nm, n_embd_v_gqa,
|
||||
+ move.len, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, i));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, move.src));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- nm, n_embd_v_gqa,
|
||||
+ move.len, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, id));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, move.dst));
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
||||
}
|
||||
-
|
||||
- i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
@@ -17325,7 +17318,7 @@ struct llm_build_context {
|
||||
}
|
||||
};
|
||||
|
||||
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
llama_ubatch dummy = {};
|
||||
dummy.equal_seqs = true;
|
||||
|
||||
@@ -17335,7 +17328,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
|
||||
|
||||
llm.init();
|
||||
|
||||
- struct ggml_cgraph * result = llm.build_defrag(ids);
|
||||
+ struct ggml_cgraph * result = llm.build_defrag(moves);
|
||||
|
||||
llm.free();
|
||||
|
||||
@@ -18351,7 +18344,12 @@ static int llama_decode_internal(
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
- const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
+ auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
+ if (!slot) {
|
||||
+ llama_kv_cache_defrag(kv_self);
|
||||
+ llama_kv_cache_update(&lctx);
|
||||
+ slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
+ }
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
@@ -18756,8 +18754,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
- // number of cells moved
|
||||
- uint32_t n_moves = 0;
|
||||
+ // groups of cells moved
|
||||
+ std::vector<struct llama_kv_defrag_move> moves;
|
||||
|
||||
// each move requires 6*n_layer tensors (see build_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
@@ -18821,19 +18819,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
- // should we stop searching for the next move?
|
||||
- bool stop = false;
|
||||
-
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
- if (n_moves == max_moves) {
|
||||
- stop = true;
|
||||
- break;
|
||||
- }
|
||||
-
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
@@ -18849,8 +18839,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
kv_self.head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
- n_moves++;
|
||||
+ moves.push_back({i1, i0 + nf, 1});
|
||||
cont = true;
|
||||
+ } else {
|
||||
+ moves.back().len++;
|
||||
}
|
||||
|
||||
nf++;
|
||||
@@ -18860,22 +18852,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
- if (stop || n_moves == max_moves) {
|
||||
- break;
|
||||
- }
|
||||
-
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
||||
- if (n_moves == 0) {
|
||||
+ if (moves.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
- //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
-
|
||||
- //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
||||
+ //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
@@ -18950,11 +18936,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
#else
|
||||
// ggml_graph defrag
|
||||
|
||||
- ggml_backend_sched_reset(lctx.sched.get());
|
||||
+ for (std::size_t i = 0; i < moves.size(); i += max_moves) {
|
||||
+ std::vector<struct llama_kv_defrag_move> chunk;
|
||||
+ auto end = std::min(i + max_moves, moves.size());
|
||||
+ chunk.assign(moves.begin() + i, moves.begin() + end);
|
||||
|
||||
- ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
|
||||
+ ggml_backend_sched_reset(lctx.sched.get());
|
||||
+
|
||||
+ //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
|
||||
+ ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
|
||||
|
||||
- llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
+ }
|
||||
#endif
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
@@ -433,14 +433,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
if errors.Is(err, llama.ErrKvCacheFull) {
|
||||
slog.Debug("defragmenting kv cache")
|
||||
s.cache.lc.KvCacheDefrag()
|
||||
err = s.lc.Decode(batch)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
if crossAttention {
|
||||
|
||||
@@ -23,13 +23,16 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/registry"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -984,8 +987,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized: access denied")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
for range 2 {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
@@ -1023,13 +1024,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
var re registry.Errs
|
||||
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
|
||||
if re.HasCode(registry.ErrCodeAnonymous) {
|
||||
// if the error is due to anonymous access return a custom error
|
||||
// this error is used by the CLI to direct a user to add their key to an account
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, re
|
||||
}
|
||||
return nil, errtypes.UnknownOllamaKey{
|
||||
Key: pubKey,
|
||||
}
|
||||
}
|
||||
return nil, re
|
||||
}
|
||||
|
||||
// Fallback to returning the raw response if parsing fails
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errUnauthorized
|
||||
// should never be reached
|
||||
return nil, fmt.Errorf("failed to make upload request")
|
||||
}
|
||||
|
||||
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||
|
||||
@@ -16,6 +16,6 @@ type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *UnknownOllamaKey) Error() string {
|
||||
func (e UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
|
||||
37
types/registry/error.go
Normal file
37
types/registry/error.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
|
||||
|
||||
type Err struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Errs represents the structure of error responses from the registry
|
||||
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
|
||||
type Errs struct {
|
||||
Errors []Err `json:"errors"`
|
||||
}
|
||||
|
||||
func (e Errs) Error() string {
|
||||
if len(e.Errors) == 0 {
|
||||
return "unknown registry error"
|
||||
}
|
||||
var msgs []string
|
||||
for _, err := range e.Errors {
|
||||
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
|
||||
}
|
||||
return strings.Join(msgs, "; ")
|
||||
}
|
||||
|
||||
func (e Errs) HasCode(code string) bool {
|
||||
return slices.ContainsFunc(e.Errors, func(err Err) bool {
|
||||
return err.Code == code
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user