diff --git a/fs/ggml/gguf.go b/fs/ggml/gguf.go index b694deadb..9be410862 100644 --- a/fs/ggml/gguf.go +++ b/fs/ggml/gguf.go @@ -582,7 +582,8 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error { if !strings.HasPrefix(k, arch+".") && !strings.HasPrefix(k, "general.") && !strings.HasPrefix(k, "adapter.") && - !strings.HasPrefix(k, "tokenizer.") { + !strings.HasPrefix(k, "tokenizer.") && + !strings.HasPrefix(k, "split.") { k = arch + "." + k } @@ -597,6 +598,8 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error { var err error switch v := v.(type) { + case uint16: + err = writeGGUF(ws, ggufTypeUint16, v) case uint32, FileType: err = writeGGUF(ws, ggufTypeUint32, v) case uint64: diff --git a/server/create.go b/server/create.go index 30cb8b9c9..02badbfba 100644 --- a/server/create.go +++ b/server/create.go @@ -39,6 +39,8 @@ var ( errUnknownType = errors.New("unknown type") errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified") errFilePath = errors.New("file path must be relative") + errIncompleteShardedGGUF = errors.New("missing some GGUF splits") + errExtraShardedGGUF = errors.New("extra GGUF splits found") ) func broadcastKV(main *ggml.GGML, subs ...*ggml.GGML) { @@ -58,6 +60,76 @@ func broadcastKV(main *ggml.GGML, subs ...*ggml.GGML) { } } +func baseLayerSortNCheckSan(baseLayers *[]*layerGGML) error { + slices.SortStableFunc(*baseLayers, func(a, b *layerGGML) int { + var aScore, bScore int + if a.GGML == nil { + // chat template and parameter can be added here. use very big number to move them at last + aScore = 0x7fffffff + } else { + aSplit := a.GGML.KV().GGUFSplitInfo() + if aSplit == nil { + aScore = -1 + } else { + aScore = int(aSplit.No) + } + } + if b.GGML == nil { + bScore = 0x7fffffff + } else { + bSplit := b.GGML.KV().GGUFSplitInfo() + if bSplit == nil { + bScore = -1 + } else { + bScore = int(bSplit.No) + } + } + return cmp.Compare(aScore, bScore) + }) + // sanity check for layers + { + ggmlPtrs := make([]*ggml.GGML, 0, len(*baseLayers)) + firstSplitCount := -1 + foundSplitNos := make([]uint16, 0) + for i, layer := range *baseLayers { + if i == 0 { + if layer.GGML == nil { + // First item should be GGUF after sorting + return errNoFilesProvided + } + } + if layer.GGML != nil && layer.GGML.KV().GGUFSplitInfo() != nil { + if firstSplitCount == -1 { + if layer.GGML.KV().GGUFSplitInfo().No != 0 { + return errIncompleteShardedGGUF + } + firstSplitCount = int(layer.GGML.KV().GGUFSplitInfo().Count) + foundSplitNos = append(foundSplitNos, layer.KV().GGUFSplitInfo().No) + } else if firstSplitCount != int(layer.KV().GGUFSplitInfo().Count) { + return errExtraShardedGGUF + } else { + if foundSplitNos[len(foundSplitNos)-1] == layer.KV().GGUFSplitInfo().No { + return errExtraShardedGGUF + } else if foundSplitNos[len(foundSplitNos)-1] != layer.KV().GGUFSplitInfo().No-1 { + return errIncompleteShardedGGUF + } else { + foundSplitNos = append(foundSplitNos, layer.KV().GGUFSplitInfo().No) + } + } + // only gguf splits should be included + ggmlPtrs = append(ggmlPtrs, layer.GGML) + } + } + if firstSplitCount != -1 && len(foundSplitNos) != firstSplitCount { + return errIncompleteShardedGGUF + } + if len(ggmlPtrs) > 1 { + broadcastKV(ggmlPtrs[0], ggmlPtrs[1:]...) + } + } + return nil +} + func (s *Server) CreateHandler(c *gin.Context) { config := &ConfigV2{ OS: "linux", @@ -175,45 +247,11 @@ func (s *Server) CreateHandler(c *gin.Context) { return } // Sort baseLayers here to ensure that split model will be correctly ordered - splitsFoundWhileSorting := false - slices.SortStableFunc(baseLayers, func(a, b *layerGGML) int { - var aScore, bScore int - if a.GGML == nil { - // chat template and parameter can be added here. use very big number to move them at last - aScore = 0x7fffffff - } else { - aSplit := a.GGML.KV().GGUFSplitInfo() - if aSplit == nil { - aScore = -1 - } else { - aScore = int(aSplit.No) - } - } - if b.GGML == nil { - bScore = 0x7fffffff - } else { - bSplit := b.GGML.KV().GGUFSplitInfo() - if bSplit == nil { - bScore = -1 - } else { - bScore = int(bSplit.No) - } - } - if aScore > -1 && aScore < 0x7fffffff && bScore > -1 && bScore < 0x7fffffff { - splitsFoundWhileSorting = true - } - return cmp.Compare(aScore, bScore) - }) - if splitsFoundWhileSorting { - ggmlPtrs := make([]*ggml.GGML, 0, len(baseLayers)) - for _, layer := range baseLayers { - if layer.GGML != nil && layer.GGML.KV().GGUFSplitInfo() != nil { - // only gguf splits should be included - ggmlPtrs = append(ggmlPtrs, layer.GGML) - } - } - if len(ggmlPtrs) > 1 { - broadcastKV(ggmlPtrs[0], ggmlPtrs[1:]...) + if !remote { + err := baseLayerSortNCheckSan(&baseLayers) + if err != nil { + ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest} + return } } diff --git a/server/routes_create_test.go b/server/routes_create_test.go index 909ebfe53..78d228ba4 100644 --- a/server/routes_create_test.go +++ b/server/routes_create_test.go @@ -954,3 +954,236 @@ func TestDetectModelTypeFromFiles(t *testing.T) { } }) } + +func TestShardedGGUF(t *testing.T) { + gin.SetMode(gin.TestMode) + p := t.TempDir() + t.Setenv("OLLAMA_MODELS", p) + + _, fullDigest := createBinFile(t, ggml.KV{}, []*ggml.Tensor{}) + _, splitDigest1 := createBinFile(t, ggml.KV{ + "split.no": uint16(0), + "split.count": uint16(3), + }, []*ggml.Tensor{}) + _, splitDigest2 := createBinFile(t, ggml.KV{ + "split.no": uint16(1), + "split.count": uint16(3), + }, []*ggml.Tensor{}) + _, splitDigest3 := createBinFile(t, ggml.KV{ + "split.no": uint16(2), + "split.count": uint16(3), + }, []*ggml.Tensor{}) + _, splitDigest4 := createBinFile(t, ggml.KV{ + "split.no": uint16(0), + "split.count": uint16(4), + }, []*ggml.Tensor{}) + _, splitDigest5 := createBinFile(t, ggml.KV{ + "general.architecture": "test1", + "split.no": uint16(1), + "split.count": uint16(3), + }, []*ggml.Tensor{}) + + var s Server + + t.Run("single full gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-single-full", + Files: map[string]string{"test.gguf": fullDigest}, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + fmt.Println(w) + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + manifest, err := ParseNamedManifest(model.ParseName("test-single-full")) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + for i, layer := range manifest.Layers { + if i != 0 { + t.Fatalf("expect 1 layer, actually found layer with index %d", i) + } else if layer.Digest != fullDigest { + t.Fatalf("expect digest %s, actual %s", fullDigest, layer.Digest) + } + } + }) + + t.Run("complete split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-complete-split", + Files: map[string]string{ + "test-00001-of-00003.gguf": splitDigest1, + "test-00002-of-00003.gguf": splitDigest2, + "test-00003-of-00003.gguf": splitDigest3, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + fmt.Println(w) + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + correctOrder := []string{ + splitDigest1, splitDigest2, splitDigest3, + } + + manifest, err := ParseNamedManifest(model.ParseName("test-complete-split")) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + for i, layer := range manifest.Layers { + if i >= 3 { + t.Fatalf("expect 3 layers, actually found layer with index %d", i) + } else if layer.Digest != correctOrder[i] { + t.Fatalf("expect digest %s, actual %s", correctOrder[i], layer.Digest) + } + } + }) + + t.Run("complete split misordered gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-complete-split-misorder", + Files: map[string]string{ + "test-00003-of-00003.gguf": splitDigest3, + "test-00001-of-00003.gguf": splitDigest1, + "test-00002-of-00003.gguf": splitDigest2, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + fmt.Println(w) + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + correctOrder := []string{ + splitDigest1, splitDigest2, splitDigest3, + } + + manifest, err := ParseNamedManifest(model.ParseName("test-complete-split-misorder")) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + for i, layer := range manifest.Layers { + if i >= 3 { + t.Fatalf("expect 3 layers, actually found layer with index %d", i) + } else if layer.Digest != correctOrder[i] { + t.Fatalf("expect digest %s, actual %s", correctOrder[i], layer.Digest) + } + } + }) + + t.Run("mixed full and split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-full-split-mixing", + Files: map[string]string{ + "test-00002-of-00003.gguf": splitDigest2, + "test-00003-of-00003.gguf": splitDigest3, + "test1.gguf": fullDigest, + "test-00001-of-00003.gguf": splitDigest1, + }, + Stream: &stream, + }) + + if w.Code != http.StatusOK { + fmt.Println(w) + t.Fatalf("expected status code 200, actual %d", w.Code) + } + + correctOrder := []string{ + fullDigest, splitDigest1, splitDigest2, splitDigest3, + } + + manifest, err := ParseNamedManifest(model.ParseName("test-full-split-mixing")) + if err != nil { + t.Fatalf("parse manifest: %v", err) + } + for i, layer := range manifest.Layers { + if i >= 4 { + t.Fatalf("expect 4 layers, actually found layer with index %d", i) + } else if layer.Digest != correctOrder[i] { + t.Fatalf("expect digest %s, actual %s", correctOrder[i], layer.Digest) + } + } + }) + + t.Run("mixed wrong split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-extra-split", + Files: map[string]string{ + "test-00002-of-00003.gguf": splitDigest2, + "test-00003-of-00003.gguf": splitDigest3, + "test-00001-of-00003.gguf": splitDigest1, + "test1-00001-of-00004.gguf": splitDigest4, + }, + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + + t.Run("mixed same count wrong split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-extra-split", + Files: map[string]string{ + "test-00002-of-00003.gguf": splitDigest2, + "test-00003-of-00003.gguf": splitDigest3, + "test-00001-of-00003.gguf": splitDigest1, + "test1-00002-of-00003.gguf": splitDigest5, + }, + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + t.Run("missing head split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-extra-split", + Files: map[string]string{ + "test-00002-of-00003.gguf": splitDigest2, + "test-00003-of-00003.gguf": splitDigest3, + }, + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + t.Run("missing mid split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-extra-split", + Files: map[string]string{ + "test-00001-of-00003.gguf": splitDigest1, + "test-00003-of-00003.gguf": splitDigest3, + }, + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + t.Run("missing tail split gguf", func(t *testing.T) { + w := createRequest(t, s.CreateHandler, api.CreateRequest{ + Name: "test-extra-split", + Files: map[string]string{ + "test-00001-of-00003.gguf": splitDigest1, + "test-00002-of-00003.gguf": splitDigest2, + }, + Stream: &stream, + }) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected status code 400, actual %d", w.Code) + } + }) + +}