ViT/multimodal token-budget admission + max_pixels clamp #1290
sufubao wants to merge 4 commits into ModelTC:main from
Conversation
… processor max_pixels clamp

Introduces three layers of OOM protection for dynamic-resolution multimodal models (Qwen2.5/3/3.5-VL, qwen3-omni, tarsier2), analogous to the LLM side's batch_max_tokens + max_req_total_len pair:

- --visual_batch_max_tokens: per-step ViT admission budget measured in image output tokens (post spatial_merge). pull_batch_with_budget in the ViT scheduler stops adding images once cumulative token_num would exceed the budget. The first image is always admitted to avoid deadlock.
- --visual_image_max_tokens: per-image hard cap. When multimodal is enabled, defaults to visual_batch_max_tokens, which in turn defaults to batch_max_tokens. Asserts visual_image_max_tokens <= visual_batch_max_tokens at startup.
- clamp_processor_max_pixels: at tokenizer/ViT load time, tightens the HF image processor's max_pixels so smart_resize auto-clamps any oversized image to fit the per-image token budget. Runtime token_num > budget becomes structurally impossible for Qwen-VL-family models. Treats processor.max_pixels=None (e.g. the Qwen3.5-VL AutoProcessor) as "looser than any budget" and always applies the computed allowed_max_pixels.

enforce_image_token_budget remains as defense-in-depth in the HTTP managers, firing for InternVL and any non-processor-clamped path.

Also fixes a heartbeat regression in the ViT infer worker: rank 0 used to block indefinitely on the infer queue while non-zero ranks waited on a gloo broadcast with a 30-minute timeout, crashing the workers during idle periods. Rank 0 now polls with a 60s timeout and broadcasts an empty batch as a heartbeat so all ranks stay in sync.

Tests: batching, budget validation, processor clamp, and budget enforcement (20+ tests, all pass).
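For intuition, here is a minimal sketch of how a per-image pixel ceiling can be derived from a token budget for the Qwen-VL family; the patch_size and spatial_merge_size defaults are Qwen2-VL-style values, and both helper bodies are illustrative assumptions, not the PR's exact clamp_processor_max_pixels implementation.

```python
def compute_allowed_max_pixels(max_image_tokens: int,
                               patch_size: int = 14,
                               spatial_merge_size: int = 2) -> int:
    # After the spatial merge, each output token covers a square of
    # (patch_size * spatial_merge_size) pixels per side, so the pixel budget
    # is simply tokens * that area; smart_resize downscales anything larger.
    pixels_per_token = (patch_size * spatial_merge_size) ** 2
    return max_image_tokens * pixels_per_token


def clamp_processor_max_pixels(processor, max_image_tokens: int) -> None:
    allowed = compute_allowed_max_pixels(max_image_tokens)
    current = getattr(processor, "max_pixels", None)
    # max_pixels=None is treated as "looser than any budget", so the computed
    # ceiling always applies; otherwise only tighten, never loosen.
    if current is None or current > allowed:
        processor.max_pixels = allowed
```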
Code Review
This pull request introduces a multimodal token budget system to prevent OOM errors in dynamic-resolution models by adding --visual_batch_max_tokens for per-step ViT admission and --visual_image_max_tokens for per-image hard caps. The implementation includes model-level processor clamping, request-level enforcement, and a budget-aware scheduler for the ViT. Review feedback identified a critical synchronization issue where returning tasks to the shared queue in the batching logic causes desync across Tensor Parallel ranks. Additionally, it was noted that a hardcoded batch size limit in the manager currently restricts the scheduler to single-image batches, and the test suite should be updated to verify strict FIFO ordering to prevent future regressions.
```python
infer_queue.put(task)
break
...
next_tokens = task.token_num or 0
if max_tokens is not None and total_tokens + next_tokens > max_tokens:
    infer_queue.put(task)
    semaphore.release()
```
Using infer_queue.put(task) to return an item to the queue causes a critical synchronization issue in multi-GPU (Tensor Parallel) setups.
In lightllm, the infer_queue is populated identically across all TP ranks. Rank 0 performs the admission logic and broadcasts the final batch size to other ranks. If Rank 0 pulls a task, finds it exceeds the budget, and puts it back, that task is moved to the end of Rank 0's queue. However, other ranks (which only pull the number of items Rank 0 actually admitted) will still have that task at the front of their respective queues. This leads to a desync in the next inference step, where Rank 0 processes the next available item while other ranks process the item that was put back.
To fix this, Rank 0 should avoid returning items to the shared queue. Instead, it should maintain a local buffer for "leftover" tasks and check this buffer before pulling from the queue in subsequent steps.
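For illustration, a minimal sketch of the buffer-based approach, assuming a standard queue.Queue of tasks that expose a token_num attribute; the function name, the leftover buffer, and the omission of semaphore handling are hypothetical simplifications, not code from this PR.

```python
import queue
from collections import deque

# rank-0-local buffer for tasks that did not fit in the previous batch
leftover: deque = deque()

def pull_batch_with_budget_buffered(infer_queue: queue.Queue, max_tokens: int):
    batch, total_tokens = [], 0
    while True:
        # Drain the local leftover first so global FIFO order is preserved
        # without ever re-inserting into the shared queue.
        if leftover:
            task = leftover.popleft()
        else:
            try:
                task = infer_queue.get_nowait()
            except queue.Empty:
                break
        next_tokens = task.token_num or 0
        # The first image is always admitted to avoid deadlock.
        if batch and total_tokens + next_tokens > max_tokens:
            leftover.append(task)  # keep it locally for the next step
            break
        batch.append(task)
        total_tokens += next_tokens
    return batch
```

Because every rank's queue is populated identically and non-zero ranks pop exactly len(batch) items, the item held back in rank 0's local buffer is still at the front of the other ranks' queues, so the next step stays in sync.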
| "quant_type": self.args.vit_quant_type, | ||
| "quant_cfg": self.args.vit_quant_cfg, | ||
| "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), | ||
| "visual_batch_max_tokens": self.args.visual_batch_max_tokens, |
The max_batch_size passed to the ViT worker is currently capped at 1 due to the min(..., 1) logic on line 87. This effectively disables the new token-budget admission feature for batching multiple images, as the scheduler will never attempt to pull more than one image per step.
If the intention of this PR is to allow batching multiple small images within a token budget, this cap should be removed or adjusted (e.g., to self.infer_batch_size // self.vit_dp).
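Restated as code, the reviewer's suggested adjustment might look like the following, assuming the intent is to floor the per-DP batch size at 1 rather than cap it there; this is a hypothetical illustration, not code present in the PR:

```python
# hypothetical: let the token budget, not a hard cap of 1, bound the ViT batch
"max_batch_size": max(self.infer_batch_size // self.vit_dp, 1),
```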
```python
self.assertIn(300, remaining)
self.assertIn(400, remaining)
```
The use of assertIn here masks the reordering issue mentioned in the pull_batch_with_budget feedback. If the queue order is not strictly FIFO after a budget cutoff, it will cause desyncs in distributed environments.
Please update the test to verify the exact order of the remaining items in the queue using assertEqual to ensure that items are returned to the front (or handled via a buffer) rather than appended to the end.
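A sketch of the stricter assertion, assuming remaining is an ordered sequence of the token counts still queued after the cutoff (the values 300 and 400 come from the existing test):

```python
# assert the exact FIFO order of the leftovers, not just membership
self.assertEqual([300, 400], list(remaining))
```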
… clamp

Two correctness fixes for the visual token-budget admission landed earlier in this PR.

1. TP rank divergence on budget skip. ``pull_batch_with_budget`` is rank-0-only; non-zero ranks pop ``len(images)`` items from their own identical FIFO queues to follow rank 0's decision. The old ``infer_queue.put(task)`` on a budget/sem-permit reject appended the rejected item to the *tail* of rank 0's queue, so subsequent calls saw a different FIFO order than rank N's queues: different images were encoded on different ranks, corrupting TP visual inference. Add a ``_put_front`` helper that re-inserts at the head under ``infer_queue.mutex`` and use it on both reject paths. Regression tests added for the budget-skip and sem-skip orderings.

2. Qwen3-VL / Qwen3.5-VL tokenizer used the pre-clamp budget. HF's Qwen3-VL image processor stores the per-image limit in ``processor.size["longest_edge"]``, and ``QWen3VLTokenizer.__init__`` reads that key into ``self.max_pixel``. ``clamp_processor_max_pixels`` only wrote ``processor.max_pixels``, so the tokenizer's ``get_image_token_length`` kept using the original (huge) limit. With the default ``visual_image_max_tokens = batch_max_tokens``, large images were wrongly rejected by ``enforce_image_token_budget`` even though the ViT processor would have resized them in-budget. Clamp both ``processor.max_pixels`` and ``processor.size["longest_edge"]`` (when present) so every reader sees the same tightened bound.

Tests: 21/21 pass for the two affected files.
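A minimal sketch of the head re-insert described above, assuming a standard queue.Queue (whose internal deque is named ``queue`` and whose ``not_empty`` condition shares ``infer_queue.mutex``); the PR's actual helper may differ in detail.

```python
import queue

def _put_front(infer_queue: queue.Queue, task) -> None:
    # Re-insert a rejected task at the head of the FIFO so rank 0's queue
    # keeps the same order as the other ranks' identical queues.
    with infer_queue.mutex:
        infer_queue.queue.appendleft(task)
        # Wake any consumer blocked in Queue.get().
        infer_queue.not_empty.notify()
```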
Two follow-ups from a second review pass.
1. Visual budget defaults are auto-derived (not None) when --enable_multimodal
is on, but the CLI help and docs still claimed "Default: None (disabled)".
This was a user-visible behavior mismatch and existing deployments had no
documented way to keep the old unlimited behavior.
- Accept ``0`` as an explicit opt-out sentinel for both
--visual_batch_max_tokens and --visual_image_max_tokens; api_start.py
converts the 0 back to None internally so all downstream consumers
(pull_batch_with_budget, clamp_processor_max_pixels) treat it as
disabled.
- Rewrite the CLI help and docs (api_server_args.rst, qwen35_deployment.rst)
to describe the actual default (auto-derived from --batch_max_tokens) and
point to the 0 opt-out for backward compatibility.
- Loosen the visual_image_max_tokens <= visual_batch_max_tokens assertion
to skip when either is None (the disabled case).
2. ``_put_front`` was bumping ``infer_queue.unfinished_tasks`` after manually
re-inserting an item that ``Queue.put`` had already counted. ``Queue.get``
does not decrement that counter (only ``task_done()`` does), so the bump
would corrupt the queue invariant and make any future ``Queue.join()`` /
``task_done()`` usage hang. Today ``visualserver`` does not call
``join``/``task_done``, so this is latent — but the contract was wrong.
- Drop the ``unfinished_tasks += 1`` line; keep the lock and the
``not_empty.notify()`` so blocked consumers still wake up.
- Add a regression test that runs a put → get → reject → task_done loop and
  then calls ``Queue.join()``; ``join`` would hang under the old code.
Tests: 22/22 pass on the affected files.
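As a small illustration of the 0-as-opt-out handling from item 1 of the commit message above, assuming an argparse-style args object like the one api_start.py receives; the helper name is hypothetical.

```python
def normalize_visual_budgets(args) -> None:
    # On the CLI, 0 means "keep the old unlimited behavior"; downstream
    # consumers (pull_batch_with_budget, clamp_processor_max_pixels) only
    # understand None as disabled, so map 0 back to None here.
    for name in ("visual_batch_max_tokens", "visual_image_max_tokens"):
        if getattr(args, name, None) == 0:
            setattr(args, name, None)
```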
…ual_batch_max_tokens

The per-image cap was always defaulted to ``visual_batch_max_tokens`` and asserted ``<=`` it, because the "first image always admitted" deadlock-avoidance rule implies a single image must fit in one batch. Two knobs for the same value were just surface area without a real degree of freedom. Collapse to one knob:

- Remove ``--visual_image_max_tokens`` from CLI, ``StartArgs``, ``api_start.py``, and the inter-block validation that enforced the redundant ``<=`` invariant.
- All clamp call sites (qwen2_vl / qwen2_5_vl / qwen3_vl / qwen3_omni / tarsier2 visual loaders, and ``server/tokenizer.py``) now read ``visual_batch_max_tokens``.
- ``enforce_image_token_budget`` (httpserver + httpserver_for_pd_master) and the error message in ``multimodal_utils.py`` now reference ``visual_batch_max_tokens``.
- Rename the helper's parameter from ``visual_image_max_tokens`` to ``max_image_tokens`` since it no longer mirrors a CLI flag: it is just the numeric budget the caller computes from ``visual_batch_max_tokens``.
- Update CLI help and docs (``api_server_args.rst``, ``qwen35_deployment.rst``) to describe the single-knob design.
- Update unit tests accordingly.

Tests: 22/22 pass on the affected files.
Summary
Adds three layers of OOM protection for dynamic-resolution multimodal models (Qwen2.5/3/3.5-VL, qwen3-omni, tarsier2), analogous to the LLM side's batch_max_tokens + max_req_total_len pair, plus a heartbeat fix for the ViT infer worker.

- --visual_batch_max_tokens: per-step ViT admission budget measured in image output tokens (post spatial_merge). pull_batch_with_budget in the ViT scheduler stops adding images once cumulative token_num would exceed the budget. The first image is always admitted to avoid deadlock.
- --visual_image_max_tokens: per-image hard cap. When enable_multimodal is set, defaults to visual_batch_max_tokens, which in turn defaults to batch_max_tokens. Asserts visual_image_max_tokens <= visual_batch_max_tokens at startup.
- clamp_processor_max_pixels: at tokenizer/ViT load time, tightens the HF image processor's max_pixels so smart_resize auto-clamps any oversized image to fit the per-image token budget. Runtime token_num > budget becomes structurally impossible for Qwen-VL-family models. Treats processor.max_pixels=None (e.g. the Qwen3.5-VL AutoProcessor) as "looser than any budget".
- enforce_image_token_budget remains as defense-in-depth in the HTTP managers, firing for InternVL and any non-processor-clamped path.

Also fixes a heartbeat regression in the ViT infer worker: rank 0 used to block indefinitely on the infer queue while non-zero ranks waited on a gloo broadcast with a 30-minute timeout, crashing workers during idle periods. Rank 0 now polls with a 60s timeout and broadcasts an empty batch as a heartbeat so all ranks stay in sync.
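A rough sketch of the heartbeat pattern described above, assuming torch.distributed with a gloo group and one identically populated queue per rank; the function name and single-task admission are illustrative simplifications, not the PR's worker loop.

```python
import queue
import torch.distributed as dist

def pull_or_heartbeat(infer_queue: queue.Queue, rank: int, gloo_group):
    tasks = []
    if rank == 0:
        try:
            # Poll with a finite timeout instead of blocking forever, so idle
            # periods never outlast the gloo broadcast timeout on other ranks.
            tasks.append(infer_queue.get(timeout=60))
        except queue.Empty:
            pass  # nothing arrived: broadcast an empty batch as a heartbeat
    count = [len(tasks)]
    dist.broadcast_object_list(count, src=0, group=gloo_group)
    if rank != 0:
        # Follow rank 0's decision by popping the same number of items.
        tasks = [infer_queue.get() for _ in range(count[0])]
    return tasks
```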
Test plan
max_pixels=None regression