Skip to content
14 changes: 14 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Llama:
def __init__(
self,
model_path: str,
clip_model_path: Optional[str] = None,
*,
# Model Params
n_gpu_layers: Union[int, Literal["auto", "all"]] = "auto",
Expand Down Expand Up @@ -171,6 +172,7 @@ def __init__(
log_filters: Optional[Sequence[str]] = None,
log_filters_case_sensitive: bool = True,
# Extra Params
chat_handler_kwargs: Dict[str, Any] = {},
**kwargs, # type: ignore
):
"""Load a llama.cpp model from `model_path`.
Expand Down Expand Up @@ -706,6 +708,18 @@ def __init__(
print(f"Failed to load metadata: {e}", file=sys.stderr)

if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)

if clip_model_path is not None:
if self.chat_handler is not None and self.verbose:
print("Warning: Both `chat_handler` and `clip_model_path` are not null. Chat handler will be overwritten.", flush = True)

self.chat_handler = llama_chat_format.GenericMTMDChatHandler(
gguf_metadata = self.metadata,
clip_model_path = clip_model_path,
verbose = self.verbose,
**chat_handler_kwargs
)
print(f"Model desc: {self.model_desc}, "
f"Model size: {self.model_size / (1024 * 1024):.2f} MB, "
f"Model metadata: {self.metadata}",
Expand Down
56 changes: 52 additions & 4 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2887,10 +2887,14 @@ def __init__(
raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {clip_model_path}")

# Pre-compile Jinja template
if not hasattr(self, "chat_format") or self.chat_format is None:
self.chat_format = self.CHAT_FORMAT

self._chat_format_parser_tags = []
self.chat_template = ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
).from_string(self.CHAT_FORMAT)
).from_string(self.chat_format)

self._exit_stack = ExitStack()

Expand Down Expand Up @@ -2992,13 +2996,13 @@ def _get_media_items(self, messages: List[llama_types.ChatCompletionRequestMessa
media_items.append({"url": url, "type": "image"})

# 2. Audio Processing
elif content_type in ["audio_url", "input_audio"]:
elif content_type in ["audio", "audio_url", "input_audio"]:
if not self.is_support_audio:
raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support audio inputs.")

# Case A: Handle custom/forward-compatible audio_url format
if content_type == "audio_url":
audio_url = content["audio_url"]
if content_type == "audio_url" or content_type == "audio":
audio_url = content[content_type]
url = audio_url if isinstance(audio_url, str) else audio_url["url"]
media_items.append({"url": url, "type": "audio"})
# Case B: Handle OpenAI standard input_audio format
Expand Down Expand Up @@ -3117,6 +3121,13 @@ def _process_mtmd_prompt(
tool_choice=tool_choice,
**getattr(self, 'extra_template_arguments', {})
)

for tag in self._chat_format_parser_tags:
if tag not in text:
continue

text = text.replace(tag, media_marker)

# Replace image_url by media_marker in text
for item in media_items:
text = text.replace(item["url"], media_marker)
Expand Down Expand Up @@ -3828,6 +3839,43 @@ def from_pretrained(
**kwargs,
)

class GenericMTMDChatHandler(MTMDChatHandler):
    """Multimodal chat handler that uses the chat template shipped with the model.

    The Jinja chat template is read from the ``tokenizer.chat_template`` entry
    of the GGUF metadata passed to the constructor instead of being hard-coded
    on the class.
    """

    # Media placeholder strings searched for in the model's chat template.
    # Those actually present are collected into ``_chat_format_parser_tags``
    # on every call.
    KNOWN_MEDIA_TAGS = [
        "<|image_pad|>",
        "<|audio_pad|>",
        "<|video_pad|>",
        "<|image|>",
        "<|audio|>",
        "<|video|>",
        "[IMG]",
    ]

    def __init__(
        self,
        gguf_metadata: Dict[str, Any],
        clip_model_path: str,
        verbose: bool = True,
        **kwargs
    ) -> None:
        """Create a handler from GGUF metadata and a clip/mmproj model path.

        Args:
            gguf_metadata: Model metadata mapping; must contain the key
                ``tokenizer.chat_template``.
            clip_model_path: Forwarded unchanged to the parent constructor.
            verbose: When true, the extracted template is printed; also
                forwarded to the parent constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.

        Raises:
            ValueError: If the metadata has no ``tokenizer.chat_template``.
        """
        self.model_metadata = gguf_metadata
        # Set before delegating to the parent constructor.
        # NOTE(review): presumably the parent compiles `self.chat_format`
        # when it is already set instead of its own CHAT_FORMAT - confirm.
        self.chat_format = self.model_metadata.get("tokenizer.chat_template", None)

        # Validate first so a missing template fails immediately instead of
        # logging a misleading "Got chat template ... None" message.
        if self.chat_format is None:
            raise ValueError("Failed to get model chat template automatically.")

        if verbose:
            print(f"Got chat template from model:\n```jinja\n{self.chat_format}\n```", flush=True)

        super().__init__(clip_model_path=clip_model_path, verbose=verbose, **kwargs)

    def __call__(self, **kwargs):
        """Record the media tags used by the template, then run the parent handler.

        All keyword arguments are passed through to the parent ``__call__``
        and its result is returned unchanged.
        """
        # Recomputed on each call from the current `chat_format`.
        # NOTE(review): presumably consumed by the parent's prompt processing
        # to swap these tags for the media marker - confirm.
        self._chat_format_parser_tags = [
            tag for tag in self.KNOWN_MEDIA_TAGS if tag in self.chat_format
        ]

        if self.verbose:
            print(f"{self.log_prefix} - Start processing")

        # Use parent implementation
        return super().__call__(**kwargs)

class Llava15ChatHandler(MTMDChatHandler):
CHAT_FORMAT = (
Expand Down
Loading