fix: Handle multi-chunk image encodings from mtmd

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2025-07-28 11:38:39 -04:00
parent 444c2bf248
commit 74d1f478e3
1 changed files with 22 additions and 18 deletions

View File

@ -498,27 +498,31 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
return nil, errors.New("unable to tokenize mtmd embedding from image")
}
nChunks := C.mtmd_input_chunks_size(ic)
if nChunks != 1 {
return nil, errors.New("image-only mtmd input tokenized to multiple chunks!")
}
chunk := C.mtmd_input_chunks_get(ic, 0)
// Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk")
}
// Get the embedding
embd := C.mtmd_get_output_embd(c.c)
// Copy embeddings over to go slice
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
numEmbed := llamaContext.Model().NEmbd()
s := unsafe.Slice((*float32)(embd), numEmbed*numTokens)
embed := make([][]float32, numTokens)
lastChunkSize := 0
for i := range int(nChunks) {
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
lastChunkSize = numTokens
// Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk")
}
}
// Get the embeddings
embed := make([][]float32, lastChunkSize)
embd := C.mtmd_get_output_embd(c.c)
if nil == embd {
return nil, errors.New("failed to get image embedding")
}
// Extend the embedding array for each token
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
rows := make([]float32, len(s))
copy(rows, s)
for i := range embed {
for i := range lastChunkSize {
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
}