Compare commits
1 Commits
brucemacd/
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
04314765f2 |
77
cmd/cmd.go
77
cmd/cmd.go
@@ -8,7 +8,6 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
@@ -18,11 +17,9 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -33,13 +30,11 @@ import (
|
||||
"github.com/containerd/console"
|
||||
"github.com/mattn/go-runewidth"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
@@ -47,7 +42,6 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -522,64 +516,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
// unknownKey handles key validation when a connection fails due to an unknown key.
|
||||
// It attempts to open the browser for interactive sessions to let users connect their key,
|
||||
// falling back to command-line instructions for non-interactive sessions.
|
||||
// Returns nil if browser flow succeeds, or an error with connection instructions otherwise.
|
||||
func unknownKey(unknownKeyErr error) error {
|
||||
// find SSH public key in the error message
|
||||
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
|
||||
sshKeyPattern := `ssh-\w+ [^\s"]+`
|
||||
re := regexp.MustCompile(sshKeyPattern)
|
||||
matches := re.FindStringSubmatch(unknownKeyErr.Error())
|
||||
|
||||
if len(matches) > 0 {
|
||||
serverPubKey := matches[0]
|
||||
|
||||
localPubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
|
||||
// try the ollama service public key
|
||||
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
|
||||
if err != nil {
|
||||
return unknownKeyErr
|
||||
}
|
||||
localPubKey = strings.TrimSpace(string(svcPubKey))
|
||||
}
|
||||
|
||||
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
|
||||
if serverPubKey != localPubKey {
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
if term.IsTerminal(int(os.Stdout.Fd())) && !envconfig.Noninteractive() {
|
||||
// URL encode the key and device name for the browser URL
|
||||
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(localPubKey))
|
||||
d, _ := os.Hostname()
|
||||
encodedDevice := url.QueryEscape(d)
|
||||
browserURL := fmt.Sprintf("https://ollama.com/connect?host=%s&key=%s", encodedDevice, encodedKey)
|
||||
if err := browser.OpenURL(browserURL); err == nil {
|
||||
fmt.Println("Opening browser to connect your device...")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var msg strings.Builder
|
||||
msg.WriteString(unknownKeyErr.Error())
|
||||
msg.WriteString("\n\nYour ollama key is:\n")
|
||||
msg.WriteString(localPubKey)
|
||||
msg.WriteString("\nAdd your key at:\n")
|
||||
msg.WriteString("https://ollama.com/settings/keys")
|
||||
|
||||
return errors.New(msg.String())
|
||||
}
|
||||
|
||||
return unknownKeyErr
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -628,19 +564,10 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
if p != nil {
|
||||
p.Stop()
|
||||
}
|
||||
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
return unknownKey(err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
@@ -651,7 +578,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
destination := n.String()
|
||||
if isOllamaHost {
|
||||
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
|
||||
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
|
||||
}
|
||||
fmt.Printf("\nYou can find your model at:\n\n")
|
||||
@@ -1547,8 +1474,6 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_GPU_OVERHEAD"],
|
||||
envVars["OLLAMA_LOAD_TIMEOUT"],
|
||||
})
|
||||
case pushCmd:
|
||||
appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_NONINTERACTIVE"]})
|
||||
default:
|
||||
appendEnvDocs(cmd, envs)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
@@ -369,13 +368,15 @@ func TestGetModelfileName(t *testing.T) {
|
||||
|
||||
func TestPushHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
}{
|
||||
{
|
||||
modelName: "successful-push",
|
||||
name: "successful push",
|
||||
modelName: "test-model",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
@@ -388,8 +389,8 @@ func TestPushHandler(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != "successful-push" {
|
||||
t.Errorf("expected model name 'successful-push', got %s", req.Name)
|
||||
if req.Name != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
}
|
||||
|
||||
// Simulate progress updates
|
||||
@@ -408,10 +409,11 @@ func TestPushHandler(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
|
||||
},
|
||||
{
|
||||
modelName: "unauthorized-push",
|
||||
name: "unauthorized push",
|
||||
modelName: "unauthorized-model",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -426,29 +428,10 @@ func TestPushHandler(t *testing.T) {
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
{
|
||||
modelName: "unknown-key-err",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
uerr := errtypes.UnknownOllamaKey{
|
||||
Key: "aaa",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": uerr.Error(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "unauthorized: unknown ollama key \"aaa\"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelName, func(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
||||
handler(w, r)
|
||||
|
||||
@@ -165,9 +165,6 @@ var (
|
||||
IntelGPU = Bool("OLLAMA_INTEL_GPU")
|
||||
// MultiUserCache optimizes prompt caching for multi-user scenarios
|
||||
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
|
||||
// Noninteractive is true when CLI interactive features should be disabled.
|
||||
// This affects features like automatic browser opening.
|
||||
Noninteractive = Bool("OLLAMA_NONINTERACTIVE")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
@@ -253,7 +250,6 @@ func AsMap() map[string]EnvVar {
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
|
||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_NONINTERACTIVE": {"OLLAMA_NONINTERACTIVE", Noninteractive(), "Disable interactive CLI features, such as automatically opening the browser"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
||||
1
go.mod
1
go.mod
@@ -36,7 +36,6 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
|
||||
3
go.sum
3
go.sum
@@ -159,8 +159,6 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
|
||||
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -283,7 +281,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
|
||||
|
||||
99
llama/llama.cpp
vendored
99
llama/llama.cpp
vendored
@@ -3051,13 +3051,6 @@ struct llama_kv_cache {
|
||||
}
|
||||
};
|
||||
|
||||
// block of KV slots to move when defragging
|
||||
struct llama_kv_defrag_move {
|
||||
uint32_t src;
|
||||
uint32_t dst;
|
||||
uint32_t len;
|
||||
};
|
||||
|
||||
struct llama_control_vector {
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
@@ -10835,23 +10828,35 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
for (const auto & move : moves) {
|
||||
for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == ids.size()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
n_embd_k_gqa, move.len,
|
||||
n_embd_k_gqa, nm,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
|
||||
|
||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
n_embd_k_gqa, move.len,
|
||||
n_embd_k_gqa, nm,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
||||
|
||||
ggml_tensor * view_v_src;
|
||||
ggml_tensor * view_v_dst;
|
||||
@@ -10859,29 +10864,31 @@ struct llm_build_context {
|
||||
if (flash_attn) {
|
||||
// NOTE: the V cache is not transposed when using flash attention
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
n_embd_v_gqa, move.len,
|
||||
n_embd_v_gqa, nm,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
n_embd_v_gqa, move.len,
|
||||
n_embd_v_gqa, nm,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
move.len, n_embd_v_gqa,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
ggml_row_size(kv_self.v_l[il]->type, move.src));
|
||||
ggml_row_size(kv_self.v_l[il]->type, i));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
move.len, n_embd_v_gqa,
|
||||
nm, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
ggml_row_size(kv_self.v_l[il]->type, move.dst));
|
||||
ggml_row_size(kv_self.v_l[il]->type, id));
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
@@ -17344,7 +17351,7 @@ struct llm_build_context {
|
||||
}
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
llama_ubatch dummy = {};
|
||||
dummy.equal_seqs = true;
|
||||
|
||||
@@ -17354,7 +17361,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
|
||||
|
||||
llm.init();
|
||||
|
||||
struct ggml_cgraph * result = llm.build_defrag(moves);
|
||||
struct ggml_cgraph * result = llm.build_defrag(ids);
|
||||
|
||||
llm.free();
|
||||
|
||||
@@ -18370,12 +18377,7 @@ static int llama_decode_internal(
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
llama_kv_cache_defrag(kv_self);
|
||||
llama_kv_cache_update(&lctx);
|
||||
slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
}
|
||||
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
@@ -18780,8 +18782,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
// groups of cells moved
|
||||
std::vector<struct llama_kv_defrag_move> moves;
|
||||
// number of cells moved
|
||||
uint32_t n_moves = 0;
|
||||
|
||||
// each move requires 6*n_layer tensors (see build_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
@@ -18845,11 +18847,19 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
// should we stop searching for the next move?
|
||||
bool stop = false;
|
||||
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
if (n_moves == max_moves) {
|
||||
stop = true;
|
||||
break;
|
||||
}
|
||||
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
@@ -18865,10 +18875,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
kv_self.head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
moves.push_back({i1, i0 + nf, 1});
|
||||
n_moves++;
|
||||
cont = true;
|
||||
} else {
|
||||
moves.back().len++;
|
||||
}
|
||||
|
||||
nf++;
|
||||
@@ -18878,16 +18886,22 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
if (stop || n_moves == max_moves) {
|
||||
break;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
||||
if (moves.size() == 0) {
|
||||
if (n_moves == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
|
||||
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
@@ -18962,18 +18976,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
#else
|
||||
// ggml_graph defrag
|
||||
|
||||
for (std::size_t i = 0; i < moves.size(); i += max_moves) {
|
||||
std::vector<struct llama_kv_defrag_move> chunk;
|
||||
auto end = std::min(i + max_moves, moves.size());
|
||||
chunk.assign(moves.begin() + i, moves.begin() + end);
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
|
||||
ggml_backend_sched_reset(lctx.sched.get());
|
||||
ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
|
||||
|
||||
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
|
||||
ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
|
||||
|
||||
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
}
|
||||
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
#endif
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Fri, 13 Dec 2024 16:11:59 -0800
|
||||
Subject: [PATCH] llama: Ensure KV cache is fully defragmented.
|
||||
|
||||
Sometimes the KV cache requires defragmentation even without
|
||||
triggering the threshold heuristic. In this case, decoding
|
||||
will not being able to find a KV cache slot. This is particularly
|
||||
difficult for the caller to handle if it happens in between
|
||||
ubatches. To avoid this, we should immediately trigger a defrag.
|
||||
|
||||
In addition, a heavily fragmented cache can require more than
|
||||
max_moves to defragment. Currently, we stop when we hit the limit
|
||||
but this can leave a cache that still does not have adequate space
|
||||
even after defragmentation is triggered. Instead, we should do
|
||||
multiple batches of processing until everything is complete.
|
||||
---
|
||||
src/llama.cpp | 99 ++++++++++++++++++++++++---------------------------
|
||||
1 file changed, 46 insertions(+), 53 deletions(-)
|
||||
|
||||
diff --git a/src/llama.cpp b/src/llama.cpp
|
||||
index 4778a9ed..654e32bc 100644
|
||||
--- a/src/llama.cpp
|
||||
+++ b/src/llama.cpp
|
||||
@@ -3025,6 +3025,13 @@ struct llama_kv_cache {
|
||||
}
|
||||
};
|
||||
|
||||
+// block of KV slots to move when defragging
|
||||
+struct llama_kv_defrag_move {
|
||||
+ uint32_t src;
|
||||
+ uint32_t dst;
|
||||
+ uint32_t len;
|
||||
+};
|
||||
+
|
||||
struct llama_control_vector {
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
@@ -10802,35 +10809,23 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
- struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
|
||||
+ struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
- for (uint32_t i = 0; i < ids.size(); ++i) {
|
||||
- const uint32_t id = ids[i];
|
||||
-
|
||||
- if (i == id || id == ids.size()) {
|
||||
- continue;
|
||||
- }
|
||||
-
|
||||
- uint32_t nm = 1;
|
||||
-
|
||||
- while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
||||
- nm++;
|
||||
- }
|
||||
-
|
||||
+ for (const auto & move : moves) {
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||
|
||||
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
- n_embd_k_gqa, nm,
|
||||
+ n_embd_k_gqa, move.len,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
- ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i));
|
||||
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
|
||||
|
||||
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
|
||||
- n_embd_k_gqa, nm,
|
||||
+ n_embd_k_gqa, move.len,
|
||||
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||
- ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id));
|
||||
+ ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
|
||||
|
||||
ggml_tensor * view_v_src;
|
||||
ggml_tensor * view_v_dst;
|
||||
@@ -10838,31 +10833,29 @@ struct llm_build_context {
|
||||
if (flash_attn) {
|
||||
// NOTE: the V cache is not transposed when using flash attention
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- n_embd_v_gqa, nm,
|
||||
+ n_embd_v_gqa, move.len,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- n_embd_v_gqa, nm,
|
||||
+ n_embd_v_gqa, move.len,
|
||||
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
|
||||
} else {
|
||||
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- nm, n_embd_v_gqa,
|
||||
+ move.len, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, i));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, move.src));
|
||||
|
||||
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
|
||||
- nm, n_embd_v_gqa,
|
||||
+ move.len, n_embd_v_gqa,
|
||||
ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
|
||||
- ggml_row_size(kv_self.v_l[il]->type, id));
|
||||
+ ggml_row_size(kv_self.v_l[il]->type, move.dst));
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
||||
}
|
||||
-
|
||||
- i += nm - 1;
|
||||
}
|
||||
|
||||
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
||||
@@ -17325,7 +17318,7 @@ struct llm_build_context {
|
||||
}
|
||||
};
|
||||
|
||||
-static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
+static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
|
||||
llama_ubatch dummy = {};
|
||||
dummy.equal_seqs = true;
|
||||
|
||||
@@ -17335,7 +17328,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
|
||||
|
||||
llm.init();
|
||||
|
||||
- struct ggml_cgraph * result = llm.build_defrag(ids);
|
||||
+ struct ggml_cgraph * result = llm.build_defrag(moves);
|
||||
|
||||
llm.free();
|
||||
|
||||
@@ -18351,7 +18344,12 @@ static int llama_decode_internal(
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
- const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
+ auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
+ if (!slot) {
|
||||
+ llama_kv_cache_defrag(kv_self);
|
||||
+ llama_kv_cache_update(&lctx);
|
||||
+ slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||
+ }
|
||||
if (!slot) {
|
||||
return 1;
|
||||
}
|
||||
@@ -18756,8 +18754,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
|
||||
//const int64_t t_start = ggml_time_us();
|
||||
|
||||
- // number of cells moved
|
||||
- uint32_t n_moves = 0;
|
||||
+ // groups of cells moved
|
||||
+ std::vector<struct llama_kv_defrag_move> moves;
|
||||
|
||||
// each move requires 6*n_layer tensors (see build_defrag)
|
||||
// - source view, destination view, copy operation
|
||||
@@ -18821,19 +18819,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
// are we moving a continuous block of memory?
|
||||
bool cont = false;
|
||||
|
||||
- // should we stop searching for the next move?
|
||||
- bool stop = false;
|
||||
-
|
||||
// go back and move the nf cells to the hole
|
||||
for (; i1 < n_kv; ++i1) {
|
||||
auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||
- if (n_moves == max_moves) {
|
||||
- stop = true;
|
||||
- break;
|
||||
- }
|
||||
-
|
||||
cont = false;
|
||||
continue;
|
||||
}
|
||||
@@ -18849,8 +18839,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
kv_self.head = n_used;
|
||||
|
||||
if (!cont) {
|
||||
- n_moves++;
|
||||
+ moves.push_back({i1, i0 + nf, 1});
|
||||
cont = true;
|
||||
+ } else {
|
||||
+ moves.back().len++;
|
||||
}
|
||||
|
||||
nf++;
|
||||
@@ -18860,22 +18852,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
- if (stop || n_moves == max_moves) {
|
||||
- break;
|
||||
- }
|
||||
-
|
||||
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
|
||||
|
||||
i0 += nh - 1;
|
||||
}
|
||||
|
||||
- if (n_moves == 0) {
|
||||
+ if (moves.size() == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
- //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
-
|
||||
- //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
|
||||
+ //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
|
||||
|
||||
#if 0
|
||||
// CPU defrag
|
||||
@@ -18950,11 +18936,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
#else
|
||||
// ggml_graph defrag
|
||||
|
||||
- ggml_backend_sched_reset(lctx.sched.get());
|
||||
+ for (std::size_t i = 0; i < moves.size(); i += max_moves) {
|
||||
+ std::vector<struct llama_kv_defrag_move> chunk;
|
||||
+ auto end = std::min(i + max_moves, moves.size());
|
||||
+ chunk.assign(moves.begin() + i, moves.begin() + end);
|
||||
|
||||
- ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
|
||||
+ ggml_backend_sched_reset(lctx.sched.get());
|
||||
+
|
||||
+ //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
|
||||
+ ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
|
||||
|
||||
- llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
+ llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
|
||||
+ }
|
||||
#endif
|
||||
|
||||
//const int64_t t_end = ggml_time_us();
|
||||
@@ -433,7 +433,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
if errors.Is(err, llama.ErrKvCacheFull) {
|
||||
slog.Debug("defragmenting kv cache")
|
||||
s.cache.lc.KvCacheDefrag()
|
||||
err = s.lc.Decode(batch)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if crossAttention {
|
||||
|
||||
@@ -700,24 +700,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
if len(req.Format) > 0 {
|
||||
switch string(req.Format) {
|
||||
case `null`, `""`:
|
||||
// Field was set, but "missing" a value. We accept
|
||||
// these as "not set".
|
||||
break
|
||||
case `"json"`:
|
||||
switch {
|
||||
case bytes.Equal(req.Format, []byte(`""`)) || bytes.Equal(req.Format, []byte(`null`)):
|
||||
// fallthrough
|
||||
case bytes.Equal(req.Format, []byte(`"json"`)):
|
||||
request["grammar"] = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
}
|
||||
|
||||
case bytes.HasPrefix(req.Format, []byte("{")):
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
default:
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,34 +39,25 @@ func TestLLMServerCompletionFormat(t *testing.T) {
|
||||
|
||||
cancel() // prevent further processing if request makes it past the format check
|
||||
|
||||
checkValid := func(err error) {
|
||||
checkCanceled := func(err error) {
|
||||
t.Helper()
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("Completion: err = %v; expected context.Canceled", err)
|
||||
}
|
||||
}
|
||||
|
||||
valids := []string{
|
||||
// "missing"
|
||||
``,
|
||||
`""`,
|
||||
`null`,
|
||||
|
||||
// JSON
|
||||
`"json"`,
|
||||
`{"type":"object"}`,
|
||||
}
|
||||
valids := []string{`"json"`, `{"type":"object"}`, ``, `""`, `null`}
|
||||
for _, valid := range valids {
|
||||
err := s.Completion(ctx, CompletionRequest{
|
||||
Options: new(api.Options),
|
||||
Format: []byte(valid),
|
||||
}, nil)
|
||||
checkValid(err)
|
||||
checkCanceled(err)
|
||||
}
|
||||
|
||||
err := s.Completion(ctx, CompletionRequest{
|
||||
Options: new(api.Options),
|
||||
Format: nil, // missing format
|
||||
}, nil)
|
||||
checkValid(err)
|
||||
checkCanceled(err)
|
||||
}
|
||||
|
||||
@@ -23,16 +23,13 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/registry"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -987,6 +984,8 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized: access denied")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
for range 2 {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
@@ -1024,33 +1023,13 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
var re registry.Errs
|
||||
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
|
||||
if re.HasCode(registry.ErrCodeAnonymous) {
|
||||
// if the error is due to anonymous access return a custom error
|
||||
// this error is used by the CLI to direct a user to add their key to an account
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, re
|
||||
}
|
||||
return nil, errtypes.UnknownOllamaKey{
|
||||
Key: pubKey,
|
||||
}
|
||||
}
|
||||
return nil, re
|
||||
}
|
||||
|
||||
// Fallback to returning the raw response if parsing fails
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
// should never be reached
|
||||
return nil, fmt.Errorf("failed to make upload request")
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||
|
||||
@@ -16,6 +16,6 @@ type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e UnknownOllamaKey) Error() string {
|
||||
func (e *UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
|
||||
|
||||
type Err struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Errs represents the structure of error responses from the registry
|
||||
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
|
||||
type Errs struct {
|
||||
Errors []Err `json:"errors"`
|
||||
}
|
||||
|
||||
func (e Errs) Error() string {
|
||||
if len(e.Errors) == 0 {
|
||||
return "unknown registry error"
|
||||
}
|
||||
var msgs []string
|
||||
for _, err := range e.Errors {
|
||||
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
|
||||
}
|
||||
return strings.Join(msgs, "; ")
|
||||
}
|
||||
|
||||
func (e Errs) HasCode(code string) bool {
|
||||
return slices.ContainsFunc(e.Errors, func(err Err) bool {
|
||||
return err.Code == code
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user