Skip to content

Commit 07a71ae

Browse files
committed
Multimodal embedding (tested on Qwen-VL-Embedding)
1 parent 4088f7b commit 07a71ae

4 files changed

Lines changed: 308 additions & 4 deletions

File tree

llama_cpp/_internals.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,12 @@ def reset(self):
672672
if self.batch is not None:
673673
self.batch.n_tokens = 0
674674

675-
def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_all: bool):
675+
def set_batch(self,
676+
batch: Sequence[int],
677+
n_past: llama_cpp.llama_pos,
678+
logits_all: bool,
679+
logits_last: bool = True
680+
):
676681
if len(batch) > self.n_tokens_capacity:
677682
raise IndexError(f"Input batch size {len(batch)} exceeds capacity {self.n_tokens_capacity}")
678683

@@ -684,7 +689,7 @@ def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_al
684689
self.batch.seq_id[i][0] = 0
685690
self.batch.n_seq_id[i] = 1
686691
self.batch.logits[i] = logits_all
687-
self.batch.logits[n_tokens - 1] = True
692+
self.batch.logits[n_tokens - 1] = logits_last
688693

689694
def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
690695
n_tokens = len(batch)

llama_cpp/llama.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@
5858
from ._logger import set_verbose
5959
from ._utils import suppress_stdout_stderr
6060

61+
from .mtmd_cpp import mtmd_context_params_default, mtmd_init_from_file
62+
from .mtmd import MultiModalContext
63+
6164

6265
class Llama:
6366
"""High-level Python wrapper for a llama.cpp model."""
@@ -130,6 +133,10 @@ def __init__(
130133
# Misc
131134
spm_infill: bool = False,
132135
verbose: bool = True,
136+
mmproj_path: str = None,
137+
mmproj_use_gpu: Optional[bool] = None,
138+
image_min_tokens: int = -1,
139+
image_max_tokens: int = -1,
133140
# Extra Params
134141
**kwargs, # type: ignore
135142
):
@@ -426,6 +433,29 @@ def __init__(
426433
)
427434
)
428435

436+
if mmproj_path != None:
437+
mparams = mtmd_context_params_default();
438+
mparams.use_gpu = mmproj_use_gpu if mmproj_use_gpu != None else n_gpu_layers == -1
439+
mparams.print_timings = verbose
440+
mparams.n_threads = self.n_threads
441+
mparams.flash_attn_type = self.context_params.flash_attn_type
442+
mparams.warmup = True
443+
if image_min_tokens > 0:
444+
mparams.image_min_tokens = image_min_tokens
445+
if image_max_tokens > 0:
446+
mparams.image_max_tokens = image_max_tokens
447+
448+
with suppress_stdout_stderr(disable=verbose):
449+
mctx = mtmd_init_from_file(mmproj_path.encode("utf-8"), self._model.model, mparams)
450+
if mctx is None:
451+
raise RuntimeError(f"failed to load multimodal projection '{mmproj_path}'")
452+
453+
self.mtmd_context = self._stack.enter_context(
454+
contextlib.closing(
455+
MultiModalContext(mctx)
456+
)
457+
)
458+
429459
# Check for Encoder-Decoder architecture
430460
self._has_encoder = self._model.has_encoder()
431461
self._has_decoder = self._model.has_decoder()

llama_cpp/llama_embedding.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
LLAMA_POOLING_TYPE_LAST,
1313
LLAMA_POOLING_TYPE_RANK, # Specifically for Reranking models
1414
)
15+
from .mtmd import MediaChunk, mtmd_tokenize, mtmd_prefill
16+
from ._utils import suppress_stdout_stderr
1517

1618
# Normalization modes for embedding vectors
1719
# See: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
@@ -168,7 +170,7 @@ def embed(
168170
if self.verbose:
169171
llama_cpp.llama_perf_context_reset(ctx)
170172
self._batch.reset()
171-
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
173+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
172174

173175
# Initialize State Variables
174176
results: List[Any] = []
@@ -219,7 +221,7 @@ def _decode_batch():
219221
results.append(data)
220222

221223
self._batch.reset()
222-
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
224+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
223225
batch_seq_lens = []
224226

225227
# Main Streaming Loop
@@ -427,3 +429,68 @@ def create_embedding(
427429
print(f"Warning: Failed to calculate similarity matrix: {e}")
428430

429431
return response
432+
433+
434+
def embed_multimodal(
435+
self,
436+
prompt: str,
437+
files: List[bytes | str] = [],
438+
439+
normalize: int = NORM_MODE_EUCLIDEAN,
440+
return_count: bool = False,
441+
) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
442+
443+
ctx = self._ctx.ctx
444+
mctx = self.mtmd_context.ctx
445+
446+
# Determine if it is in Rerank mode
447+
try:
448+
pooling_type = self.pooling_type()
449+
except AttributeError:
450+
pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED
451+
is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK)
452+
is_none = (pooling_type == LLAMA_POOLING_TYPE_NONE) # Token-level embedding
453+
454+
out_dim = self.n_embd()
455+
456+
if self.verbose:
457+
type_str = "TOKEN (None)" if is_none else ("RANK (Score)" if is_rank else "SEQ (Vector)")
458+
print(f"LlamaEmbedding Debug: Mode={type_str} | Pooling={pooling_type} | Dim={out_dim}")
459+
460+
# Reset Context and Batch
461+
if self.verbose:
462+
llama_cpp.llama_perf_context_reset(ctx)
463+
self._batch.reset()
464+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
465+
466+
# Initialize State Variables
467+
result: Any = None
468+
469+
470+
with suppress_stdout_stderr(disable=self.verbose):
471+
tokens: MultimodalTokenList = mtmd_tokenize(mctx, prompt, files)
472+
473+
n_tokens = len(tokens)
474+
475+
if n_tokens == 0:
476+
result = []
477+
else:
478+
n_past = mtmd_prefill(self._ctx, mctx, self._batch, tokens)
479+
480+
# Extract Embeddings
481+
ptr = llama_cpp.llama_get_embeddings_ith(ctx, self._batch.n_tokens() - 1)
482+
data = ptr[:out_dim]
483+
data = self._normalize_vector(data, normalize)
484+
485+
result = data
486+
487+
self._batch.reset()
488+
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
489+
490+
if self.verbose:
491+
llama_cpp.llama_perf_context_print(ctx)
492+
493+
if return_count:
494+
return result, n_tokens
495+
496+
return result

llama_cpp/mtmd.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import llama_cpp
2+
from llama_cpp import LLAMA_TOKEN_NULL
3+
4+
import llama_cpp.mtmd_cpp as mtmd
5+
from .mtmd_cpp import mtmd_input_chunk_type, mtmd_free
6+
from ._internals import LlamaContext, LlamaBatch
7+
8+
import ctypes
9+
from typing import Union, List
10+
11+
class TextChunk:
12+
def __init__(self, tokens: List[int]):
13+
self.tokens = tokens
14+
self.n_tokens = len(tokens)
15+
16+
class MediaChunk:
17+
def __init__(self, chunk_ptr: ctypes.c_void_p):
18+
self.chunk_ptr = mtmd.mtmd_input_chunk_copy(chunk_ptr)
19+
self.n_tokens = mtmd.mtmd_input_chunk_get_n_tokens(self.chunk_ptr)
20+
21+
def __del__(self):
22+
if self.chunk_ptr:
23+
mtmd.mtmd_input_chunk_free(self.chunk_ptr)
24+
25+
class MultimodalTokenList:
26+
def __init__(self):
27+
self.chunks: List[Union[TextChunk, MediaChunk]] = []
28+
self.total_tokens = 0
29+
30+
def add(self, chunk_ptr: mtmd.mtmd_input_chunk_p_ctypes):
31+
chunk_type = mtmd.mtmd_input_chunk_get_type(chunk_ptr)
32+
33+
if chunk_type in [
34+
mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_IMAGE,
35+
mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO
36+
]:
37+
m_chunk = MediaChunk(chunk_ptr)
38+
self.chunks.append(m_chunk)
39+
self.total_tokens += m_chunk.n_tokens
40+
41+
elif chunk_type == mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_TEXT:
42+
n_tokens_ref = ctypes.c_size_t()
43+
text_tokens_ptr = mtmd.mtmd_input_chunk_get_tokens_text(chunk_ptr, ctypes.byref(n_tokens_ref))
44+
tokens = [text_tokens_ptr[j] for j in range(n_tokens_ref.value)]
45+
self.add_text(tokens)
46+
47+
else:
48+
raise ValueError(f"Invalid chunk type {chunk_type}")
49+
50+
def add_text(self, tokens: List[int]):
51+
if not tokens: return
52+
# combine text nodes
53+
if self.chunks and isinstance(self.chunks[-1], TextChunk):
54+
self.chunks[-1].tokens.extend(tokens)
55+
self.chunks[-1].n_tokens += len(tokens)
56+
else:
57+
self.chunks.append(TextChunk(tokens))
58+
self.total_tokens += len(tokens)
59+
60+
def __len__(self):
61+
return self.total_tokens
62+
63+
64+
class MultiModalContext:
65+
def __init__(
66+
self,
67+
ctx
68+
):
69+
self.ctx = ctx
70+
71+
def close(self):
72+
if self.ctx is None:
73+
return
74+
mtmd_free(self.ctx)
75+
self.ctx = None
76+
77+
def __del__(self):
78+
self.close()
79+
80+
81+
# Simple FNV-1a hash implementation to match fnv_hash in C++
82+
def fnv_hash(data: bytes) -> str:
83+
h = 0x811c9dc5
84+
for b in data:
85+
h = (h ^ b) * 0x01000193
86+
h &= 0xffffffff
87+
return f"{h:08x}"
88+
89+
def mtmd_tokenize(
90+
mctx: mtmd.mtmd_context_p,
91+
prompt: str,
92+
files_data: list[bytes | str]) -> MultimodalTokenList:
93+
94+
bitmaps = []
95+
do_hash = False
96+
97+
for data in files_data:
98+
99+
bmp = None
100+
if isinstance(data, str):
101+
bmp = mtmd.mtmd_helper_bitmap_init_from_file(mctx, data.encode("utf-8"))
102+
elif isinstance(data, bytes):
103+
buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
104+
bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
105+
elif isinstance(data, bytearray):
106+
buf = (ctypes.c_ubyte * len(data)).from_buffer(data)
107+
bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
108+
109+
if bmp is None:
110+
raise RuntimeError("Failed to load image or audio file")
111+
112+
if do_hash:
113+
data_ptr = mtmd.mtmd_bitmap_get_data(bmp)
114+
data_size = mtmd.mtmd_bitmap_get_n_bytes(bmp)
115+
116+
raw_node_data = ctypes.string_at(data_ptr, data_size)
117+
h = fnv_hash(raw_node_data)
118+
mtmd.mtmd_bitmap_set_id(bmp, h.encode('utf-8'))
119+
120+
bitmaps.append(bmp)
121+
122+
inp_txt = mtmd.mtmd_input_text(
123+
text=prompt.encode('utf-8'),
124+
add_special=True,
125+
parse_special=True
126+
)
127+
128+
chunks_ptr = mtmd.mtmd_input_chunks_init()
129+
130+
n_bitmaps = len(bitmaps)
131+
if n_bitmaps > 0:
132+
BitmapPtr = mtmd.mtmd_bitmap_p_ctypes * n_bitmaps
133+
bitmaps_ptr = BitmapPtr(*bitmaps)
134+
else:
135+
bitmaps_ptr = None
136+
137+
res = mtmd.mtmd_tokenize(
138+
mctx,
139+
chunks_ptr,
140+
ctypes.pointer(inp_txt),
141+
bitmaps_ptr,
142+
n_bitmaps
143+
)
144+
145+
# TODO Hash based cache
146+
for data in bitmaps:
147+
mtmd.mtmd_bitmap_free(bmp)
148+
149+
if res != 0:
150+
mtmd.mtmd_input_chunks_free(chunks_ptr)
151+
raise RuntimeError(f"Tokenization failed with code {res}")
152+
153+
st = MultimodalTokenList()
154+
155+
n_chunks = mtmd.mtmd_input_chunks_size(chunks_ptr)
156+
for i in range(n_chunks):
157+
chunk_ptr = mtmd.mtmd_input_chunks_get(chunks_ptr, i)
158+
st.add(chunk_ptr)
159+
160+
mtmd.mtmd_input_chunks_free(chunks_ptr)
161+
162+
return st
163+
164+
def mtmd_prefill(
165+
ctx: LlamaContext,
166+
mctx: mtmd.mtmd_context_p,
167+
batch: LlamaBatch,
168+
mtmd_tokens: MultimodalTokenList
169+
) -> int:
170+
n_past = 0
171+
n_batch = ctx.n_batch()
172+
total_chunks = len(mtmd_tokens.chunks)
173+
174+
for i, chunk in enumerate(mtmd_tokens.chunks):
175+
is_last_chunk = (i == total_chunks - 1)
176+
177+
if isinstance(chunk, TextChunk):
178+
batch.set_batch(
179+
chunk.tokens,
180+
n_past,
181+
logits_all=False,
182+
logits_last=is_last_chunk
183+
)
184+
ctx.decode(batch)
185+
186+
n_past += chunk.n_tokens
187+
else:
188+
new_n_past = llama_cpp.llama_pos(0)
189+
result = mtmd.mtmd_helper_eval_chunk_single(
190+
mctx,
191+
ctx.ctx,
192+
chunk.chunk_ptr,
193+
llama_cpp.llama_pos(n_past),
194+
llama_cpp.llama_seq_id(0),
195+
n_batch,
196+
False, # logits_last
197+
ctypes.byref(new_n_past)
198+
)
199+
if result != 0:
200+
raise RuntimeError(f"MTMD eval error: {result}")
201+
202+
n_past = new_n_past.value

0 commit comments

Comments
 (0)