Skip to content

Commit 37fbbdf

Browse files
committed
Multimodal embedding (tested on Qwen-VL-Embedding)
1 parent 4bfe5ed commit 37fbbdf

4 files changed

Lines changed: 318 additions & 20 deletions

File tree

llama_cpp/_internals.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -699,25 +699,24 @@ def reset(self):
699699
if self.batch is not None:
700700
self.batch.n_tokens = 0
701701

702-
def add_token(self, token: int, pos: int, seq_ids: Sequence[int], logits: bool):
703-
"""
704-
Adds a single token to the batch.
705-
This is a high-performance method for appending a single token during the generation loop,
706-
avoiding the overhead of creating temporary lists required by add_sequence.
707-
708-
Args:
709-
token: The integer ID of the token to add.
710-
pos: The logical sequence position (n_past) of this token.
711-
seq_ids: A sequence of sequence IDs this token belongs to (e.g., [0] for a standard single chat).
712-
A single token can be part of multiple sequences simultaneously.
713-
logits: A boolean flag indicating whether the backend should compute logits for this token.
714-
"""
715-
idx = self.batch.n_tokens
716-
if idx >= self.n_tokens_capacity:
717-
raise IndexError(f"LlamaBatch overflow[add_token]: Cannot add token. Capacity {self.n_tokens_capacity} reached.")
702+
def set_batch(self,
703+
batch: Sequence[int],
704+
n_past: llama_cpp.llama_pos,
705+
logits_all: bool,
706+
logits_last: bool = True
707+
):
708+
if len(batch) > self.n_tokens_capacity:
709+
raise IndexError(f"Input batch size {len(batch)} exceeds capacity {self.n_tokens_capacity}")
718710

719-
self.batch.token[idx] = token
720-
self.batch.pos[idx] = pos
711+
n_tokens = len(batch)
712+
self.batch.n_tokens = n_tokens
713+
for i in range(n_tokens):
714+
self.batch.token[i] = batch[i]
715+
self.batch.pos[i] = n_past + i
716+
self.batch.seq_id[i][0] = 0
717+
self.batch.n_seq_id[i] = 1
718+
self.batch.logits[i] = logits_all
719+
self.batch.logits[n_tokens - 1] = logits_last
721720

722721
n_seq_id = len(seq_ids)
723722
if n_seq_id > self.n_seq_max:

llama_cpp/llama.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
from ._logger import set_verbose
6060
from ._utils import suppress_stdout_stderr
6161

62+
from .mtmd_cpp import mtmd_context_params_default, mtmd_init_from_file
63+
from .mtmd import MultiModalContext
64+
6265

6366
class Llama:
6467
"""High-level Python wrapper for a llama.cpp model."""
@@ -135,6 +138,10 @@ def __init__(
135138
# Misc
136139
spm_infill: bool = False,
137140
verbose: bool = True,
141+
mmproj_path: str = None,
142+
mmproj_use_gpu: Optional[bool] = None,
143+
image_min_tokens: int = -1,
144+
image_max_tokens: int = -1,
138145
# Extra Params
139146
**kwargs, # type: ignore
140147
):
@@ -435,6 +442,29 @@ def __init__(
435442
)
436443
)
437444

445+
if mmproj_path != None:
446+
mparams = mtmd_context_params_default();
447+
mparams.use_gpu = mmproj_use_gpu if mmproj_use_gpu != None else n_gpu_layers == -1
448+
mparams.print_timings = verbose
449+
mparams.n_threads = self.n_threads
450+
mparams.flash_attn_type = self.context_params.flash_attn_type
451+
mparams.warmup = True
452+
if image_min_tokens > 0:
453+
mparams.image_min_tokens = image_min_tokens
454+
if image_max_tokens > 0:
455+
mparams.image_max_tokens = image_max_tokens
456+
457+
with suppress_stdout_stderr(disable=verbose):
458+
mctx = mtmd_init_from_file(mmproj_path.encode("utf-8"), self._model.model, mparams)
459+
if mctx is None:
460+
raise RuntimeError(f"failed to load multimodal projection '{mmproj_path}'")
461+
462+
self.mtmd_context = self._stack.enter_context(
463+
contextlib.closing(
464+
MultiModalContext(mctx)
465+
)
466+
)
467+
438468
# Check for Encoder-Decoder architecture
439469
self._has_encoder = self._model.has_encoder()
440470
self._has_decoder = self._model.has_decoder()

llama_cpp/llama_embedding.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
LLAMA_POOLING_TYPE_LAST,
1313
LLAMA_POOLING_TYPE_RANK, # Specifically for Reranking models
1414
)
15+
from .mtmd import MediaChunk, mtmd_tokenize, mtmd_prefill
16+
from ._utils import suppress_stdout_stderr
1517

1618
# Normalization modes for embedding vectors
1719
# See: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
@@ -168,7 +170,7 @@ def embed(
168170
if self.verbose:
169171
llama_cpp.llama_perf_context_reset(ctx)
170172
self._batch.reset()
171-
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
173+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
172174

173175
# Initialize State Variables
174176
results: List[Any] = []
@@ -219,7 +221,7 @@ def _decode_batch():
219221
results.append(data)
220222

221223
self._batch.reset()
222-
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
224+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
223225
batch_seq_lens = []
224226

225227
# Main Streaming Loop
@@ -439,3 +441,68 @@ def create_embedding(
439441
print(f"Warning: Failed to calculate similarity matrix: {e}")
440442

441443
return response
444+
445+
446+
def embed_multimodal(
447+
self,
448+
prompt: str,
449+
files: List[bytes | str] = [],
450+
451+
normalize: int = NORM_MODE_EUCLIDEAN,
452+
return_count: bool = False,
453+
) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
454+
455+
ctx = self._ctx.ctx
456+
mctx = self.mtmd_context.ctx
457+
458+
# Determine if it is in Rerank mode
459+
try:
460+
pooling_type = self.pooling_type()
461+
except AttributeError:
462+
pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED
463+
is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK)
464+
is_none = (pooling_type == LLAMA_POOLING_TYPE_NONE) # Token-level embedding
465+
466+
out_dim = self.n_embd()
467+
468+
if self.verbose:
469+
type_str = "TOKEN (None)" if is_none else ("RANK (Score)" if is_rank else "SEQ (Vector)")
470+
print(f"LlamaEmbedding Debug: Mode={type_str} | Pooling={pooling_type} | Dim={out_dim}")
471+
472+
# Reset Context and Batch
473+
if self.verbose:
474+
llama_cpp.llama_perf_context_reset(ctx)
475+
self._batch.reset()
476+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
477+
478+
# Initialize State Variables
479+
result: Any = None
480+
481+
482+
with suppress_stdout_stderr(disable=self.verbose):
483+
tokens: MultimodalTokenList = mtmd_tokenize(mctx, prompt, files)
484+
485+
n_tokens = len(tokens)
486+
487+
if n_tokens == 0:
488+
result = []
489+
else:
490+
n_past = mtmd_prefill(self._ctx, mctx, self._batch, tokens)
491+
492+
# Extract Embeddings
493+
ptr = llama_cpp.llama_get_embeddings_ith(ctx, self._batch.n_tokens() - 1)
494+
data = ptr[:out_dim]
495+
data = self._normalize_vector(data, normalize)
496+
497+
result = data
498+
499+
self._batch.reset()
500+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
501+
502+
if self.verbose:
503+
llama_cpp.llama_perf_context_print(ctx)
504+
505+
if return_count:
506+
return result, n_tokens
507+
508+
return result

llama_cpp/mtmd.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import llama_cpp
2+
from llama_cpp import LLAMA_TOKEN_NULL
3+
4+
import llama_cpp.mtmd_cpp as mtmd
5+
from .mtmd_cpp import mtmd_input_chunk_type, mtmd_free
6+
from ._internals import LlamaContext, LlamaBatch
7+
8+
import ctypes
9+
from typing import Union, List
10+
11+
class TextChunk:
12+
def __init__(self, tokens: List[int]):
13+
self.tokens = tokens
14+
self.n_tokens = len(tokens)
15+
16+
class MediaChunk:
17+
def __init__(self, chunk_ptr: ctypes.c_void_p):
18+
self.chunk_ptr = mtmd.mtmd_input_chunk_copy(chunk_ptr)
19+
self.n_tokens = mtmd.mtmd_input_chunk_get_n_tokens(self.chunk_ptr)
20+
21+
def __del__(self):
22+
if self.chunk_ptr:
23+
mtmd.mtmd_input_chunk_free(self.chunk_ptr)
24+
25+
class MultimodalTokenList:
26+
def __init__(self):
27+
self.chunks: List[Union[TextChunk, MediaChunk]] = []
28+
self.total_tokens = 0
29+
30+
def add(self, chunk_ptr: mtmd.mtmd_input_chunk_p_ctypes):
31+
chunk_type = mtmd.mtmd_input_chunk_get_type(chunk_ptr)
32+
33+
if chunk_type in [
34+
mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_IMAGE,
35+
mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO
36+
]:
37+
m_chunk = MediaChunk(chunk_ptr)
38+
self.chunks.append(m_chunk)
39+
self.total_tokens += m_chunk.n_tokens
40+
41+
elif chunk_type == mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_TEXT:
42+
n_tokens_ref = ctypes.c_size_t()
43+
text_tokens_ptr = mtmd.mtmd_input_chunk_get_tokens_text(chunk_ptr, ctypes.byref(n_tokens_ref))
44+
tokens = [text_tokens_ptr[j] for j in range(n_tokens_ref.value)]
45+
self.add_text(tokens)
46+
47+
else:
48+
raise ValueError(f"Invalid chunk type {chunk_type}")
49+
50+
def add_text(self, tokens: List[int]):
51+
if not tokens: return
52+
# combine text nodes
53+
if self.chunks and isinstance(self.chunks[-1], TextChunk):
54+
self.chunks[-1].tokens.extend(tokens)
55+
self.chunks[-1].n_tokens += len(tokens)
56+
else:
57+
self.chunks.append(TextChunk(tokens))
58+
self.total_tokens += len(tokens)
59+
60+
def __len__(self):
61+
return self.total_tokens
62+
63+
64+
class MultiModalContext:
65+
def __init__(
66+
self,
67+
ctx
68+
):
69+
self.ctx = ctx
70+
71+
def close(self):
72+
if self.ctx is None:
73+
return
74+
mtmd_free(self.ctx)
75+
self.ctx = None
76+
77+
def __del__(self):
78+
self.close()
79+
80+
81+
# Simple FNV-1a hash implementation to match fnv_hash in C++
82+
def fnv_hash(data: bytes) -> str:
83+
h = 0x811c9dc5
84+
for b in data:
85+
h = (h ^ b) * 0x01000193
86+
h &= 0xffffffff
87+
return f"{h:08x}"
88+
89+
def mtmd_tokenize(
90+
mctx: mtmd.mtmd_context_p,
91+
prompt: str,
92+
files_data: list[bytes | str]) -> MultimodalTokenList:
93+
94+
bitmaps = []
95+
do_hash = False
96+
97+
for data in files_data:
98+
99+
bmp = None
100+
if isinstance(data, str):
101+
bmp = mtmd.mtmd_helper_bitmap_init_from_file(mctx, data.encode("utf-8"))
102+
elif isinstance(data, bytes):
103+
buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
104+
bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
105+
elif isinstance(data, bytearray):
106+
buf = (ctypes.c_ubyte * len(data)).from_buffer(data)
107+
bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
108+
109+
if bmp is None:
110+
raise RuntimeError("Failed to load image or audio file")
111+
112+
if do_hash:
113+
data_ptr = mtmd.mtmd_bitmap_get_data(bmp)
114+
data_size = mtmd.mtmd_bitmap_get_n_bytes(bmp)
115+
116+
raw_node_data = ctypes.string_at(data_ptr, data_size)
117+
h = fnv_hash(raw_node_data)
118+
mtmd.mtmd_bitmap_set_id(bmp, h.encode('utf-8'))
119+
120+
bitmaps.append(bmp)
121+
122+
inp_txt = mtmd.mtmd_input_text(
123+
text=prompt.encode('utf-8'),
124+
add_special=True,
125+
parse_special=True
126+
)
127+
128+
chunks_ptr = mtmd.mtmd_input_chunks_init()
129+
130+
n_bitmaps = len(bitmaps)
131+
if n_bitmaps > 0:
132+
BitmapPtr = mtmd.mtmd_bitmap_p_ctypes * n_bitmaps
133+
bitmaps_ptr = BitmapPtr(*bitmaps)
134+
else:
135+
bitmaps_ptr = None
136+
137+
res = mtmd.mtmd_tokenize(
138+
mctx,
139+
chunks_ptr,
140+
ctypes.pointer(inp_txt),
141+
bitmaps_ptr,
142+
n_bitmaps
143+
)
144+
145+
# TODO Hash based cache
146+
for data in bitmaps:
147+
mtmd.mtmd_bitmap_free(bmp)
148+
149+
if res != 0:
150+
mtmd.mtmd_input_chunks_free(chunks_ptr)
151+
raise RuntimeError(f"Tokenization failed with code {res}")
152+
153+
st = MultimodalTokenList()
154+
155+
n_chunks = mtmd.mtmd_input_chunks_size(chunks_ptr)
156+
for i in range(n_chunks):
157+
chunk_ptr = mtmd.mtmd_input_chunks_get(chunks_ptr, i)
158+
st.add(chunk_ptr)
159+
160+
mtmd.mtmd_input_chunks_free(chunks_ptr)
161+
162+
return st
163+
164+
def mtmd_prefill(
165+
ctx: LlamaContext,
166+
mctx: mtmd.mtmd_context_p,
167+
batch: LlamaBatch,
168+
mtmd_tokens: MultimodalTokenList
169+
) -> int:
170+
n_past = 0
171+
n_batch = ctx.n_batch()
172+
total_chunks = len(mtmd_tokens.chunks)
173+
174+
for i, chunk in enumerate(mtmd_tokens.chunks):
175+
is_last_chunk = (i == total_chunks - 1)
176+
177+
if isinstance(chunk, TextChunk):
178+
batch.set_batch(
179+
chunk.tokens,
180+
n_past,
181+
logits_all=False,
182+
logits_last=is_last_chunk
183+
)
184+
ctx.decode(batch)
185+
186+
n_past += chunk.n_tokens
187+
else:
188+
new_n_past = llama_cpp.llama_pos(0)
189+
result = mtmd.mtmd_helper_eval_chunk_single(
190+
mctx,
191+
ctx.ctx,
192+
chunk.chunk_ptr,
193+
llama_cpp.llama_pos(n_past),
194+
llama_cpp.llama_seq_id(0),
195+
n_batch,
196+
False, # logits_last
197+
ctypes.byref(new_n_past)
198+
)
199+
if result != 0:
200+
raise RuntimeError(f"MTMD eval error: {result}")
201+
202+
n_past = new_n_past.value

0 commit comments

Comments
 (0)