From 74d1f478e35b41693e1acc51c34715869b748469 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 28 Jul 2025 11:38:39 -0400 Subject: [PATCH] fix: Handle multi-chunk image encodings from mtmd Branch: GraniteFour Signed-off-by: Gabe Goodhart --- llama/llama.go | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/llama/llama.go b/llama/llama.go index 31fdba69b..4885949b7 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -498,27 +498,31 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, return nil, errors.New("unable to tokenize mtmd embedding from image") } nChunks := C.mtmd_input_chunks_size(ic) - if nChunks != 1 { - return nil, errors.New("image-only mtmd input tokenized to multiple chunks!") - } - chunk := C.mtmd_input_chunks_get(ic, 0) - - // Encode the chunk - if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { - return nil, errors.New("unable to encode mtmd image chunk") - } - - // Get the embedding - embd := C.mtmd_get_output_embd(c.c) - - // Copy embeddings over to go slice - numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) numEmbed := llamaContext.Model().NEmbd() - s := unsafe.Slice((*float32)(embd), numEmbed*numTokens) - embed := make([][]float32, numTokens) + lastChunkSize := 0 + for i := range int(nChunks) { + chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) + numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) + lastChunkSize = numTokens + + // Encode the chunk + if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { + return nil, errors.New("unable to encode mtmd image chunk") + } + } + + // Get the embeddings + embed := make([][]float32, lastChunkSize) + embd := C.mtmd_get_output_embd(c.c) + if nil == embd { + return nil, errors.New("failed to get image embedding") + } + + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) rows := make([]float32, len(s)) copy(rows, s) - for i := range embed { + for i := range lastChunkSize { embed[i] = rows[i*numEmbed : (i+1)*numEmbed] }