server: sanity check when creating model with split gguf

This commit is contained in:
cvrunmin 2025-11-28 13:38:52 +08:00
parent b2ebfccff8
commit 10dc89faca
3 changed files with 314 additions and 40 deletions

View File

@ -582,7 +582,8 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
if !strings.HasPrefix(k, arch+".") &&
!strings.HasPrefix(k, "general.") &&
!strings.HasPrefix(k, "adapter.") &&
!strings.HasPrefix(k, "tokenizer.") {
!strings.HasPrefix(k, "tokenizer.") &&
!strings.HasPrefix(k, "split.") {
k = arch + "." + k
}
@ -597,6 +598,8 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
var err error
switch v := v.(type) {
case uint16:
err = writeGGUF(ws, ggufTypeUint16, v)
case uint32, FileType:
err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:

View File

@ -39,6 +39,8 @@ var (
errUnknownType = errors.New("unknown type")
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
errFilePath = errors.New("file path must be relative")
errIncompleteShardedGGUF = errors.New("missing some GGUF splits")
errExtraShardedGGUF = errors.New("extra GGUF splits found")
)
func broadcastKV(main *ggml.GGML, subs ...*ggml.GGML) {
@ -58,6 +60,76 @@ func broadcastKV(main *ggml.GGML, subs ...*ggml.GGML) {
}
}
// baseLayerSortNCheckSan stably sorts *baseLayers in place and then validates
// the GGUF split metadata carried by the sorted layers.
//
// Sort order (ascending score):
//   - layers with a GGML but no split info (score -1),
//   - layers with split info, ordered by their split number,
//   - layers without a GGML (score 0x7fffffff), e.g. chat template or
//     parameter layers, which are pushed to the end.
//
// Returned errors:
//   - errNoFilesProvided if the first layer after sorting has no GGML.
//   - errIncompleteShardedGGUF if the first split seen is not number 0, if
//     there is a gap between consecutive split numbers, or if fewer splits
//     were seen than the declared split count.
//   - errExtraShardedGGUF if a split declares a different count than the
//     first split, or if a split number occurs twice.
//
// When more than one split passes validation, broadcastKV is invoked with the
// first split as main and the remaining splits as subs.
func baseLayerSortNCheckSan(baseLayers *[]*layerGGML) error {
	// Very large score so that layers without a GGML are sorted last.
	const nonGGMLScore = 0x7fffffff

	score := func(l *layerGGML) int {
		if l.GGML == nil {
			return nonGGMLScore
		}
		if split := l.GGML.KV().GGUFSplitInfo(); split != nil {
			return int(split.No)
		}
		return -1
	}
	slices.SortStableFunc(*baseLayers, func(a, b *layerGGML) int {
		return cmp.Compare(score(a), score(b))
	})

	layers := *baseLayers

	// After sorting, the first item must be a GGUF; otherwise none was given.
	if len(layers) > 0 && layers[0].GGML == nil {
		return errNoFilesProvided
	}

	// Only GGUF splits are collected here; unsplit GGUFs are left untouched.
	splits := make([]*ggml.GGML, 0, len(layers))
	wantCount := -1 // -1 until the first split has been seen
	var lastNo uint16
	for _, layer := range layers {
		if layer.GGML == nil {
			continue
		}
		split := layer.GGML.KV().GGUFSplitInfo()
		if split == nil {
			continue
		}

		switch {
		case wantCount == -1:
			// The layers are sorted by split number, so the first split
			// encountered has to be number 0.
			if split.No != 0 {
				return errIncompleteShardedGGUF
			}
			wantCount = int(split.Count)
		case wantCount != int(split.Count):
			return errExtraShardedGGUF
		case lastNo == split.No:
			// Same split number twice.
			return errExtraShardedGGUF
		case lastNo != split.No-1:
			// Gap in the sequence of split numbers.
			return errIncompleteShardedGGUF
		}

		lastNo = split.No
		splits = append(splits, layer.GGML)
	}

	if wantCount != -1 && len(splits) != wantCount {
		return errIncompleteShardedGGUF
	}
	if len(splits) > 1 {
		broadcastKV(splits[0], splits[1:]...)
	}
	return nil
}
func (s *Server) CreateHandler(c *gin.Context) {
config := &ConfigV2{
OS: "linux",
@ -175,45 +247,11 @@ func (s *Server) CreateHandler(c *gin.Context) {
return
}
// Sort baseLayers here to ensure that split model will be correctly ordered
splitsFoundWhileSorting := false
slices.SortStableFunc(baseLayers, func(a, b *layerGGML) int {
var aScore, bScore int
if a.GGML == nil {
// chat template and parameter can be added here. use very big number to move them at last
aScore = 0x7fffffff
} else {
aSplit := a.GGML.KV().GGUFSplitInfo()
if aSplit == nil {
aScore = -1
} else {
aScore = int(aSplit.No)
}
}
if b.GGML == nil {
bScore = 0x7fffffff
} else {
bSplit := b.GGML.KV().GGUFSplitInfo()
if bSplit == nil {
bScore = -1
} else {
bScore = int(bSplit.No)
}
}
if aScore > -1 && aScore < 0x7fffffff && bScore > -1 && bScore < 0x7fffffff {
splitsFoundWhileSorting = true
}
return cmp.Compare(aScore, bScore)
})
if splitsFoundWhileSorting {
ggmlPtrs := make([]*ggml.GGML, 0, len(baseLayers))
for _, layer := range baseLayers {
if layer.GGML != nil && layer.GGML.KV().GGUFSplitInfo() != nil {
// only gguf splits should be included
ggmlPtrs = append(ggmlPtrs, layer.GGML)
}
}
if len(ggmlPtrs) > 1 {
broadcastKV(ggmlPtrs[0], ggmlPtrs[1:]...)
if !remote {
err := baseLayerSortNCheckSan(&baseLayers)
if err != nil {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return
}
}

View File

@ -954,3 +954,236 @@ func TestDetectModelTypeFromFiles(t *testing.T) {
}
})
}
// TestShardedGGUF drives CreateHandler with different combinations of full
// and split GGUF blobs. Accepted requests must answer 200 and produce a
// manifest whose layers match the expected digests in order; rejected
// requests must answer 400.
func TestShardedGGUF(t *testing.T) {
	gin.SetMode(gin.TestMode)

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	// blob writes a tensor-less GGUF with the given metadata and returns its digest.
	blob := func(kv ggml.KV) string {
		_, digest := createBinFile(t, kv, []*ggml.Tensor{})
		return digest
	}
	// shardKV builds the split metadata for shard `no` out of `count`.
	shardKV := func(no, count uint16) ggml.KV {
		return ggml.KV{
			"split.no":    no,
			"split.count": count,
		}
	}

	fullDigest := blob(ggml.KV{})
	splitDigest1 := blob(shardKV(0, 3))
	splitDigest2 := blob(shardKV(1, 3))
	splitDigest3 := blob(shardKV(2, 3))
	// First shard of an unrelated four-way split.
	splitDigest4 := blob(shardKV(0, 4))
	// Same split number and count as splitDigest2, but a different architecture.
	archKV := shardKV(1, 3)
	archKV["general.architecture"] = "test1"
	splitDigest5 := blob(archKV)

	cases := []struct {
		title string
		name  string
		files map[string]string
		// want lists the expected layer digests in manifest order;
		// nil means the request must be rejected with 400.
		want []string
	}{
		{
			title: "single full gguf",
			name:  "test-single-full",
			files: map[string]string{"test.gguf": fullDigest},
			want:  []string{fullDigest},
		},
		{
			title: "complete split gguf",
			name:  "test-complete-split",
			files: map[string]string{
				"test-00001-of-00003.gguf": splitDigest1,
				"test-00002-of-00003.gguf": splitDigest2,
				"test-00003-of-00003.gguf": splitDigest3,
			},
			want: []string{splitDigest1, splitDigest2, splitDigest3},
		},
		{
			title: "complete split misordered gguf",
			name:  "test-complete-split-misorder",
			files: map[string]string{
				"test-00003-of-00003.gguf": splitDigest3,
				"test-00001-of-00003.gguf": splitDigest1,
				"test-00002-of-00003.gguf": splitDigest2,
			},
			want: []string{splitDigest1, splitDigest2, splitDigest3},
		},
		{
			title: "mixed full and split gguf",
			name:  "test-full-split-mixing",
			files: map[string]string{
				"test-00002-of-00003.gguf": splitDigest2,
				"test-00003-of-00003.gguf": splitDigest3,
				"test1.gguf":               fullDigest,
				"test-00001-of-00003.gguf": splitDigest1,
			},
			want: []string{fullDigest, splitDigest1, splitDigest2, splitDigest3},
		},
		{
			title: "mixed wrong split gguf",
			name:  "test-extra-split",
			files: map[string]string{
				"test-00002-of-00003.gguf":  splitDigest2,
				"test-00003-of-00003.gguf":  splitDigest3,
				"test-00001-of-00003.gguf":  splitDigest1,
				"test1-00001-of-00004.gguf": splitDigest4,
			},
		},
		{
			title: "mixed same count wrong split gguf",
			name:  "test-extra-split",
			files: map[string]string{
				"test-00002-of-00003.gguf":  splitDigest2,
				"test-00003-of-00003.gguf":  splitDigest3,
				"test-00001-of-00003.gguf":  splitDigest1,
				"test1-00002-of-00003.gguf": splitDigest5,
			},
		},
		{
			title: "missing head split gguf",
			name:  "test-extra-split",
			files: map[string]string{
				"test-00002-of-00003.gguf": splitDigest2,
				"test-00003-of-00003.gguf": splitDigest3,
			},
		},
		{
			title: "missing mid split gguf",
			name:  "test-extra-split",
			files: map[string]string{
				"test-00001-of-00003.gguf": splitDigest1,
				"test-00003-of-00003.gguf": splitDigest3,
			},
		},
		{
			title: "missing tail split gguf",
			name:  "test-extra-split",
			files: map[string]string{
				"test-00001-of-00003.gguf": splitDigest1,
				"test-00002-of-00003.gguf": splitDigest2,
			},
		},
	}

	var s Server

	for _, tc := range cases {
		t.Run(tc.title, func(t *testing.T) {
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
				Name:   tc.name,
				Files:  tc.files,
				Stream: &stream,
			})

			// Rejection cases only assert on the status code.
			if tc.want == nil {
				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status code 400, actual %d", w.Code)
				}
				return
			}

			if w.Code != http.StatusOK {
				fmt.Println(w)
				t.Fatalf("expected status code 200, actual %d", w.Code)
			}

			manifest, err := ParseNamedManifest(model.ParseName(tc.name))
			if err != nil {
				t.Fatalf("parse manifest: %v", err)
			}

			noun := "layers"
			if len(tc.want) == 1 {
				noun = "layer"
			}
			for i, layer := range manifest.Layers {
				if i >= len(tc.want) {
					t.Fatalf("expect %d %s, actually found layer with index %d", len(tc.want), noun, i)
				}
				if layer.Digest != tc.want[i] {
					t.Fatalf("expect digest %s, actual %s", tc.want[i], layer.Digest)
				}
			}
		})
	}
}