8 changes: 8 additions & 0 deletions docs/EN/source/cookbook/qwen35_deployment.rst
@@ -233,3 +233,11 @@ Hardware Requirements
- ``--tp 8`` required to fit model weights across GPUs
- Reduce ``--max_req_total_len`` or ``--graph_max_batch_size`` if encountering OOM errors
- Use ``--data_type fp8_e4m3`` for FP8 KV quantization to further reduce memory pressure
- Multimodal deployments get ViT OOM protection by default: when
  ``--enable_multimodal`` is on, ``--visual_batch_max_tokens`` is auto-derived
  from ``--batch_max_tokens``. The same value caps both the per-step batch
  output and the per-image output (oversized images are auto-resized by the
  Qwen-VL ``max_pixels`` clamp; anything still over budget is rejected
  before reaching the ViT); the worked example below shows the pixel budget
  this implies. To tighten the budget further, pass an explicit value
  (e.g. ``--visual_batch_max_tokens 16384``); to opt out and restore the
  pre-PR behavior, pass ``--visual_batch_max_tokens 0``.
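
A worked example of the pixel budget this implies (a sketch only; it assumes
the Qwen2.5-VL defaults ``patch_size=14`` and ``merge_size=2``, and mirrors
the arithmetic in ``clamp_processor_max_pixels``)::

    patch_size, merge_size = 14, 2
    unit = patch_size * merge_size             # 28 px per merged-token edge
    max_image_tokens = 16384                   # --visual_batch_max_tokens 16384
    allowed_max_pixels = max_image_tokens * unit * unit
    print(allowed_max_pixels)                  # 12845056 px, i.e. ~3584x3584
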
21 changes: 21 additions & 0 deletions docs/EN/source/tutorial/api_server_args.rst
@@ -272,6 +272,27 @@ Multimodal Parameters

Number of images processed in each inference batch, default is ``1``

.. option:: --visual_batch_max_tokens

Per-step ViT admission budget, measured in image output tokens (post
spatial_merge). The multimodal analogue of ``--batch_max_tokens``: the
ViT scheduler stops adding images to the current batch once their
cumulative ``token_num`` would exceed this value. Useful for bounding
peak ViT memory on dynamic-resolution models (Qwen2.5/3/3.5-VL, etc.),
where one 4K image or long video can contain more patches than many
small images combined. One image is always admitted per step to avoid
deadlock when a single request is larger than the budget; to make that
safe, the same value also drives the per-image budget: oversized images
are auto-resized by the Qwen-VL processor ``max_pixels`` clamp, and any
image that still exceeds the budget is rejected with a ``ValueError``
before reaching the ViT. A sketch of the admission rule follows this
entry.

**Default behavior with** ``--enable_multimodal``: auto-derived from
``--batch_max_tokens`` so multimodal deployments get OOM protection
without explicit opt-in. Pass an explicit positive integer to override.
Pass ``0`` to opt out and restore the pre-budget behavior (only
``--visual_infer_batch_size`` applies).
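
A minimal sketch of the admission rule described above (illustrative names,
not lightllm's actual scheduler API; it assumes each queued image carries the
``token_num`` computed by ``get_image_token_length``)::

    def admit_images(queue, budget):
        batch, used = [], 0
        while queue:
            token_num = queue[0].token_num
            # The first image is always admitted to avoid deadlock; the
            # per-image cap guarantees token_num <= budget by then.
            if batch and used + token_num > budget:
                break
            batch.append(queue.pop(0))
            used += token_num
        return batch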

.. option:: --visual_gpu_ids

List of GPU IDs to use, e.g., 0 1 2
6 changes: 5 additions & 1 deletion lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -8,7 +8,8 @@
from io import BytesIO
import torch.nn as nn
from transformers.activations import ACT2FN
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels
from lightllm.utils.envs_utils import get_env_start_args
from safetensors import safe_open
from lightllm.server.multimodal_params import ImageItem
from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding
@@ -208,6 +209,9 @@ def __init__(
with open(processor_config_path, "r") as f:
processor_config_dict = json.load(f)
self.processor = Qwen2VLImageProcessor(**processor_config_dict)
clamp_processor_max_pixels(
self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen2_5_vl-vit"
)

self._init_datatype()

6 changes: 5 additions & 1 deletion lightllm/models/qwen2_vl/qwen2_visual.py
@@ -33,7 +33,8 @@
from safetensors import safe_open
from lightllm.server.multimodal_params import ImageItem
from lightllm.server.visualserver import get_vit_attn_backend
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton

@@ -244,6 +245,9 @@ def load_model(self, weight_dir):
with open(processor_config_path, "r") as f:
processor_config_dict = json.load(f)
self.processor = Qwen2VLImageProcessor(**processor_config_dict)
clamp_processor_max_pixels(
self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen2_vl-vit"
)

bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
if bin_weight_files:
55 changes: 55 additions & 0 deletions lightllm/models/qwen2_vl/vision_process.py
@@ -27,6 +27,61 @@
logger = init_logger(__name__)


def clamp_processor_max_pixels(processor, max_image_tokens, processor_name: str = "") -> None:
"""Clamp a Qwen-VL style image processor's max-pixel limit so that even a
max-sized image produces ``token_num <= max_image_tokens``.

Reuses the processor's built-in ``smart_resize`` mechanism — just tightens
the per-pixel budget so the existing resize path fits the server-wide
per-image token budget (``--visual_batch_max_tokens``). After the clamp,
``get_image_token_length`` cannot return a value above the budget, so
request-level rejection becomes a defensive no-op in practice.

Different Qwen-VL generations expose the limit on different attributes:
Qwen2-VL / Qwen2.5-VL / lightllm's own ``Qwen2VLImageProcessor`` use
``processor.max_pixels``, while HF's Qwen3-VL / Qwen3.5-VL processors store
it in ``processor.size["longest_edge"]``. Both attributes are clamped when
present so any reader (HF runtime, tokenizer ``__init__``) sees the
tightened bound.

No-op when ``max_image_tokens`` is None or the processor already enforces a
tighter bound.
"""
if max_image_tokens is None:
return
unit = processor.patch_size * processor.merge_size
allowed_max_pixels = max_image_tokens * unit * unit
if allowed_max_pixels < unit * unit:
raise ValueError(
f"max_image_tokens={max_image_tokens} is too small; "
f"need at least 1 patch's worth (={unit * unit} pixels) for {processor_name or 'processor'}."
)

# Track originals so the log line shows the pre-clamp values; some
# processors only expose one of the two schemas, so each branch is gated
# on its own attribute presence.
current_max_pixels = getattr(processor, "max_pixels", None)
size = getattr(processor, "size", None)
has_longest_edge = isinstance(size, dict) and "longest_edge" in size
current_longest_edge = size.get("longest_edge") if has_longest_edge else None

clamped = False
if current_max_pixels is None or allowed_max_pixels < current_max_pixels:
processor.max_pixels = allowed_max_pixels
clamped = True
if has_longest_edge and (current_longest_edge is None or allowed_max_pixels < current_longest_edge):
size["longest_edge"] = allowed_max_pixels
clamped = True

if clamped:
logger.info(
f"{processor_name or 'processor'}: clamping max_pixels/longest_edge to "
f"{allowed_max_pixels} (was max_pixels={current_max_pixels}, "
f"longest_edge={current_longest_edge}; "
f"max_image_tokens={max_image_tokens}, unit={unit})"
)


IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
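
A quick usage sketch of the clamp (the stub processor below is illustrative;
the real ``Qwen2VLImageProcessor`` reads these attributes from its config):

    class _StubProcessor:
        patch_size = 14
        merge_size = 2
        max_pixels = 16384 * 28 * 28               # matches MAX_PIXELS above
        size = {"longest_edge": 16384 * 28 * 28}

    p = _StubProcessor()
    clamp_processor_max_pixels(p, max_image_tokens=4096, processor_name="demo")
    assert p.max_pixels == 4096 * 28 * 28          # tightened to the budget
    assert p.size["longest_edge"] == 4096 * 28 * 28
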
6 changes: 5 additions & 1 deletion lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
@@ -27,7 +27,8 @@

from lightllm.server.multimodal_params import ImageItem
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention


@@ -225,6 +226,9 @@ def load_model(self, weight_dir):
with open(processor_config_path, "r") as f:
processor_config_dict = json.load(f)
self.processor = Qwen2VLImageProcessor(**processor_config_dict)
clamp_processor_max_pixels(
self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen3_omni-vit"
)

bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
if bin_weight_files:
6 changes: 5 additions & 1 deletion lightllm/models/qwen3_vl/qwen3_visual.py
@@ -27,7 +27,8 @@

from lightllm.server.multimodal_params import ImageItem
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor, clamp_processor_max_pixels
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention
from lightllm.utils.log_utils import init_logger

@@ -220,6 +221,9 @@ def load_model(self, weight_dir):
with open(processor_config_path, "r") as f:
processor_config_dict = json.load(f)
self.processor = Qwen2VLImageProcessor(**processor_config_dict)
clamp_processor_max_pixels(
self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="qwen3_vl-vit"
)

bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
if bin_weight_files:
6 changes: 5 additions & 1 deletion lightllm/models/tarsier2/tarsier2_visual.py
@@ -16,7 +16,8 @@
from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.server.multimodal_params import ImageItem
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image, clamp_processor_max_pixels
from lightllm.utils.envs_utils import get_env_start_args


def add_split_tokens(image_features, image_newline_embed, image_new_embed):
@@ -221,6 +222,9 @@ def load_model(self, weight_dir):
with open(processor_config_path, "r") as f:
processor_config_dict = json.load(f)
self.processor = Qwen2VLImageProcessor(**processor_config_dict)
clamp_processor_max_pixels(
self.processor, get_env_start_args().visual_batch_max_tokens, processor_name="tarsier2-vit"
)

bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
if bin_weight_files:
24 changes: 24 additions & 0 deletions lightllm/server/api_cli.py
@@ -472,6 +472,30 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--visual_infer_batch_size", type=int, default=None, help="number of images to process in each inference batch"
)
parser.add_argument(
"--visual_batch_max_tokens",
type=int,
default=None,
help="""
Per-step ViT admission budget measured in image output tokens (post
spatial_merge). The ViT scheduler stops adding images to the current
batch once their cumulative token_num would exceed this value. Acts as
the multimodal analogue of --batch_max_tokens and caps peak ViT
memory/compute for dynamic-resolution models (Qwen2.5/3/3.5-VL, etc.).
One image is always admitted per step to avoid deadlock when a single
request is larger than the budget — to make that safe, this value
also drives the per-image budget: oversized images are auto-resized
by the Qwen-VL processor max_pixels clamp, and any image that still
exceeds the budget is rejected with a ValueError before reaching the
ViT.

Default behavior when --enable_multimodal is on: auto-derived from
--batch_max_tokens so multimodal deployments get OOM protection without
explicit opt-in. Pass an explicit positive integer to override; pass 0
to opt out entirely and restore the pre-budget behavior (only the
image-count cap --visual_infer_batch_size applies).
""",
)
parser.add_argument(
"--visual_send_batch_size",
type=int,
18 changes: 18 additions & 0 deletions lightllm/server/api_start.py
@@ -272,6 +272,24 @@ def normal_or_p_d_start(args):
args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num
logger.info(f"set cpu_cache_token_page_size to {args.cpu_cache_token_page_size} for linear hybrid att model")

# Multimodal budget defaults (safety-on-by-default for multimodal deployments):
# - Unset: visual_batch_max_tokens defaults to batch_max_tokens (LLM and ViT share
#   the same budget convention).
# - 0: explicitly disabled, restoring the pre-PR "unlimited" behavior (kept for
#   backward compatibility).
# - Positive integer: used as the explicit budget.
# The same value drives the per-step batch budget, the per-image hard cap, and the
# processor max_pixels clamp. The "always admit the first image" rule requires that
# a single image fit into one batch, so the batch budget and the per-image cap are
# necessarily the same number.
if args.enable_multimodal:
if args.visual_batch_max_tokens is None:
args.visual_batch_max_tokens = args.batch_max_tokens
logger.info(
f"visual_batch_max_tokens auto-derived from batch_max_tokens = {args.batch_max_tokens} "
f"(pass --visual_batch_max_tokens 0 to opt out)"
)
elif args.visual_batch_max_tokens == 0:
logger.info("visual_batch_max_tokens explicitly disabled (=0); ViT token budget off")
args.visual_batch_max_tokens = None

# help to manage data stored on Ceph
if "s3://" in args.model_dir:
from lightllm.utils.petrel_helper import s3_model_prepare
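
The three configuration modes above, restated as a runnable sketch (a
simplified mirror of the branch, not the real startup path):

    from argparse import Namespace

    def derive(args):
        if args.enable_multimodal:
            if args.visual_batch_max_tokens is None:
                args.visual_batch_max_tokens = args.batch_max_tokens   # auto-derive
            elif args.visual_batch_max_tokens == 0:
                args.visual_batch_max_tokens = None                    # opt out
        return args

    a = derive(Namespace(enable_multimodal=True, batch_max_tokens=8192,
                         visual_batch_max_tokens=None))
    assert a.visual_batch_max_tokens == 8192
    a = derive(Namespace(enable_multimodal=True, batch_max_tokens=8192,
                         visual_batch_max_tokens=0))
    assert a.visual_batch_max_tokens is None
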
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -107,6 +107,7 @@ class StartArgs:
push_interval: int = field(default=10)
visual_node_id: int = field(default=None)
visual_infer_batch_size: int = field(default=None)
visual_batch_max_tokens: Optional[int] = field(default=None)
visual_send_batch_size: int = field(default=1)
visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
visual_tp: int = field(default=1)
10 changes: 7 additions & 3 deletions lightllm/server/httpserver/manager.py
@@ -35,6 +35,7 @@
from lightllm.utils.config_utils import get_vocab_size
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
from lightllm.utils.multimodal_utils import enforce_image_token_budget
from rpyc.utils.classic import obtain

logger = init_logger(__name__)
@@ -179,11 +180,12 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
# Only P and NORMAL nodes actually need to manage multimodal resources
if self.pd_mode.is_P_or_NORMAL():
items, md5sums, tokens_nums, datas = [], [], [], []
for img in multimodal_params.images:
for img_index, img in enumerate(multimodal_params.images):
self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params)
data = img.read()
# must after init_imageitem_extral_params
token_num = self.tokenizer.get_image_token_length(img)
enforce_image_token_budget(token_num, self.args.visual_batch_max_tokens, image_index=img_index)
md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
md5sums.append(md5sum)
img.md5 = md5sum
@@ -236,10 +238,12 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
img_count = 0
audio_tokens = 0
audio_count = 0
for img in multimodal_params.images:
for img_index, img in enumerate(multimodal_params.images):
img_count += 1
self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
image_tokens += self.tokenizer.get_image_token_length(img)
token_num = self.tokenizer.get_image_token_length(img)
enforce_image_token_budget(token_num, self.args.visual_batch_max_tokens, image_index=img_index)
image_tokens += token_num
for audio in multimodal_params.audios:
audio_count += 1
self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)
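
``enforce_image_token_budget`` itself is not part of this diff; a plausible
sketch, inferred from the ValueError behavior described in the docs (an
assumption, not the actual lightllm/utils/multimodal_utils.py source):

    def enforce_image_token_budget(token_num, budget, image_index=0):
        # budget is None when --visual_batch_max_tokens is disabled (=0)
        if budget is not None and token_num > budget:
            raise ValueError(
                f"image {image_index} needs {token_num} ViT tokens, which "
                f"exceeds --visual_batch_max_tokens={budget}"
            )
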
7 changes: 5 additions & 2 deletions lightllm/server/httpserver_for_pd_master/manager.py
@@ -21,6 +21,7 @@
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError
from lightllm.utils.envs_utils import get_pd_split_max_new_tokens
from lightllm.utils.multimodal_utils import enforce_image_token_budget
from .pd_selector import create_selector

logger = init_logger(__name__)
@@ -73,10 +74,12 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
img_count = 0
audio_tokens = 0
audio_count = 0
for img in multimodal_params.images:
for img_index, img in enumerate(multimodal_params.images):
img_count += 1
self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params)
image_tokens += self.tokenizer.get_image_token_length(img)
token_num = self.tokenizer.get_image_token_length(img)
enforce_image_token_budget(token_num, self.args.visual_batch_max_tokens, image_index=img_index)
image_tokens += token_num
for audio in multimodal_params.audios:
audio_count += 1
self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params)