-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Support encoder input chunking for SALM vLLM inference #15716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,6 +64,17 @@ | |
| _SAMPLING_RATE = 16000 | ||
| _AUDIO_CHANNELS = 1 | ||
| _DUMMY_AUDIO_DURATION_S = 40.0 | ||
| _DUMMY_AUDIO_MAX_DURATION_S = 3600.0 | ||
| _DUMMY_AUDIO_TEXT_TOKEN_RESERVE = 64 | ||
| # FastConformer preprocessor hop length, used to derive the smallest | ||
| # chunk that produces ≥ 2 feature frames (per-feature normalization | ||
| # breaks on a single frame). Mirrors | ||
| # ``encoder_chunking._get_min_chunk_size_samples`` for the canonical | ||
| # preprocessor we ship; the chunking helper probes the live featurizer | ||
| # at training time, but the prompt processor here runs before the | ||
| # perception module is loaded, so we use the same constant the helper | ||
| # would derive. | ||
| _MIN_CHUNK_SIZE_SAMPLES = 320 | ||
|
|
||
|
|
||
| # ── Helpers ───────────────────────────────────────────────────────── | ||
|
|
@@ -123,11 +134,20 @@ def get_data_parser(self) -> MultiModalDataParser: | |
| ) | ||
|
|
||
| def get_supported_mm_limits(self) -> Mapping[str, int | None]: | ||
| return {"audio": 1} | ||
| return {"audio": None} | ||
|
|
||
| def _get_encoder_chunk_size_seconds(self) -> float | None: | ||
| """Return the per-encoder-call chunk size baked into the checkpoint. | ||
|
|
||
| Mirrors the training-time ``model.encoder_chunk_size_seconds`` field | ||
| (see ``encode_audio_with_optional_chunking``). ``None`` means the | ||
| encoder runs once over the full audio, matching legacy checkpoints. | ||
| """ | ||
| return getattr(self.get_hf_config(), "encoder_chunk_size_seconds", None) | ||
|
|
||
| @staticmethod | ||
| def _estimate_audio_tokens(audio_length_samples: int) -> int: | ||
| """Predict the encoder's output frame count for an audio of N samples. | ||
| def _estimate_audio_tokens_single_pass(audio_length_samples: int) -> int: | ||
| """Predict the encoder's output frame count for one perception forward. | ||
|
|
||
| Mirrors the FastConformer preprocessing chain used by | ||
| ``AudioPerceptionModule``: STFT (n_fft=512, hop_length=160) followed | ||
|
|
@@ -151,6 +171,64 @@ def _estimate_audio_tokens(audio_length_samples: int) -> int: | |
| length = (length + add_pad) / stride + 1.0 | ||
| return max(1, int(length)) | ||
|
|
||
@classmethod
def _estimate_audio_tokens(
    cls,
    audio_length_samples: int,
    chunk_size_seconds: float | None = None,
) -> int:
    """Predict the encoder's total output frame count for an audio of N samples.

    When ``chunk_size_seconds`` is ``None`` or the audio fits in a single
    chunk, the single-pass estimate is returned. Otherwise the audio is
    treated as a run of full chunks plus an optional tail, and the per-chunk
    estimates are summed. A tail shorter than ``_MIN_CHUNK_SIZE_SAMPLES`` is
    folded into the preceding chunk. This is intended to mirror the split
    performed by ``encode_audio_with_optional_chunking`` so the placeholder
    count matches what the model emits at forward time.

    Args:
        audio_length_samples: Length of the audio in samples.
        chunk_size_seconds: Per-encoder-call chunk size in seconds, or
            ``None`` to disable chunking.

    Returns:
        The estimated number of encoder output frames.

    Raises:
        ValueError: If ``chunk_size_seconds`` is set but not positive. The
            check runs regardless of ``audio_length_samples`` so an invalid
            configuration is never silently accepted.
    """
    if chunk_size_seconds is None:
        return cls._estimate_audio_tokens_single_pass(audio_length_samples)
    if chunk_size_seconds <= 0.0:
        raise ValueError("encoder_chunk_size_seconds must be positive when set.")

    # Clamp to the minimum chunk size; that floor already guarantees >= 1.
    chunk_size_samples = max(
        int(round(chunk_size_seconds * _SAMPLING_RATE)),
        _MIN_CHUNK_SIZE_SAMPLES,
    )
    # Also covers empty / non-positive lengths: they never exceed one chunk.
    if audio_length_samples <= chunk_size_samples:
        return cls._estimate_audio_tokens_single_pass(audio_length_samples)

    # From here on there is at least one full chunk (length > chunk size).
    # Every full chunk yields the same estimate, so compute it once instead
    # of materializing one span per chunk.
    full_chunks, tail_samples = divmod(audio_length_samples, chunk_size_samples)
    frames_per_full_chunk = cls._estimate_audio_tokens_single_pass(chunk_size_samples)

    if tail_samples == 0:
        return full_chunks * frames_per_full_chunk
    if tail_samples < _MIN_CHUNK_SIZE_SAMPLES:
        # Tail too short to stand alone: merge it into the last full chunk.
        merged_frames = cls._estimate_audio_tokens_single_pass(chunk_size_samples + tail_samples)
        return (full_chunks - 1) * frames_per_full_chunk + merged_frames
    return full_chunks * frames_per_full_chunk + cls._estimate_audio_tokens_single_pass(tail_samples)
|
|
||
@classmethod
def _samples_for_audio_tokens(cls, target_tokens: int, chunk_size_seconds: float | None = None) -> int:
    """Return the smallest sample count estimated to produce ``target_tokens``.

    vLLM sizes the multimodal encoder cache from dummy inputs. The SALM
    plugin supports arbitrarily long audio by chunking the encoder forward,
    but the decoder still receives the concatenated full-audio embedding
    sequence. This inverse estimator lets ``--limit-mm-per-prompt`` audio
    length hints reserve cache for that full sequence without hard-coding a
    single maximum call duration.

    The search never looks past ``_DUMMY_AUDIO_MAX_DURATION_S`` worth of
    samples; if the target is not reached by then, that cap is returned.
    NOTE(review): the bisection assumes ``_estimate_audio_tokens`` is
    non-decreasing in the sample count -- confirm for chunked estimates.
    """
    wanted = max(1, int(target_tokens))
    ceiling = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE)

    def reaches_target(n_samples: int) -> bool:
        return cls._estimate_audio_tokens(n_samples, chunk_size_seconds) >= wanted

    # Grow the upper bound geometrically, starting from one second of audio,
    # until it satisfies the target or hits the cap.
    upper = min(_SAMPLING_RATE, ceiling)
    while upper < ceiling and not reaches_target(upper):
        upper = min(upper * 2, ceiling)

    # Bisect [1, upper] for the first sample count that reaches the target.
    lower = 1
    while lower < upper:
        probe = (lower + upper) // 2
        if reaches_target(probe):
            upper = probe
        else:
            lower = probe + 1
    return lower
|
|
||
|
|
||
| class NeMoSpeechLMMultiModalProcessor( | ||
| BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo], | ||
|
|
@@ -182,10 +260,11 @@ def _get_prompt_updates( | |
| out_mm_kwargs: MultiModalKwargsItems, | ||
| ) -> list[PromptUpdate]: | ||
| audios = mm_items.get_items("audio", AudioProcessorItems) | ||
| chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() | ||
|
|
||
| def get_replacement(item_idx: int): | ||
| audio = audios.get(item_idx) | ||
| n_tokens = self.info._estimate_audio_tokens(audio.shape[-1]) | ||
| n_tokens = self.info._estimate_audio_tokens(audio.shape[-1], chunk_size_seconds) | ||
| repl_full = _AUDIO_PLACEHOLDER * n_tokens | ||
| return PromptUpdateDetails.select_text(repl_full, _AUDIO_PLACEHOLDER) | ||
|
|
||
|
|
@@ -210,6 +289,7 @@ def _call_hf_processor( | |
| audios = mm_data.pop("audios", []) | ||
|
|
||
| if audios: | ||
| chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() | ||
| audio_list: list[torch.Tensor] = [] | ||
| audio_lengths: list[int] = [] | ||
| parts = re.split(f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt) | ||
|
|
@@ -229,7 +309,7 @@ def _call_hf_processor( | |
| ) | ||
| if audio_tensor.dim() > 1: | ||
| audio_tensor = audio_tensor.squeeze() | ||
| n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1]) | ||
| n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1], chunk_size_seconds) | ||
| parts[i] = _AUDIO_PLACEHOLDER * n_tokens | ||
| audio_list.append(audio_tensor) | ||
| audio_lengths.append(audio_tensor.shape[-1]) | ||
|
|
@@ -257,6 +337,19 @@ def get_dummy_mm_data( | |
| ) -> MultiModalDataDict: | ||
| num_audios = mm_counts.get("audio", 0) | ||
| dummy_audio_len = int(_DUMMY_AUDIO_DURATION_S * _SAMPLING_RATE) | ||
| audio_options = mm_options.get("audio") if mm_options else None | ||
| requested_audio_len = getattr(audio_options, "length", None) | ||
| if requested_audio_len: | ||
| chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() | ||
| if seq_len > _DUMMY_AUDIO_TEXT_TOKEN_RESERVE: | ||
| max_audio_tokens = seq_len - _DUMMY_AUDIO_TEXT_TOKEN_RESERVE | ||
| max_audio_len = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens( | ||
| max_audio_tokens, | ||
| chunk_size_seconds, | ||
| ) | ||
| else: | ||
| max_audio_len = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) | ||
| dummy_audio_len = min(int(requested_audio_len), max_audio_len) | ||
| return { | ||
| "audio": self._get_dummy_audios( | ||
| length=dummy_audio_len, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -49,7 +49,9 @@ | |||||
| from vllm.multimodal import MULTIMODAL_REGISTRY | ||||||
| from vllm.sequence import IntermediateTensors | ||||||
|
|
||||||
| from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking | ||||||
| from nemo.collections.speechlm2.vllm.salm.audio import ( | ||||||
| _SAMPLING_RATE, | ||||||
| NeMoSpeechLMAudioInputs, | ||||||
| NeMoSpeechLMDummyInputsBuilder, | ||||||
| NeMoSpeechLMMultiModalProcessor, | ||||||
|
|
@@ -87,6 +89,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): | |||||
| super().__init__() | ||||||
| config = vllm_config.model_config.hf_config | ||||||
| self.config = config | ||||||
| self.encoder_chunk_size_seconds = getattr(config, "encoder_chunk_size_seconds", None) | ||||||
|
|
||||||
| backend = make_backend(config) | ||||||
| self._backend = backend | ||||||
|
|
@@ -142,13 +145,21 @@ def _process_audio(self, audio_input: NeMoSpeechLMAudioInputs) -> tuple[torch.Te | |||||
| audio_signal = audio_signal.to(device=device, dtype=_AUDIO_INPUT_DTYPE) | ||||||
| audio_lengths = audio_input.audio_signal_length.to(device=device) | ||||||
|
|
||||||
| # Mirrors training (``encode_audio_with_optional_chunking``): when the | ||||||
| # checkpoint was trained with a chunked encoder (e.g. SALMAutomodel | ||||||
| # default 30 s), long audios are split into chunks before the perception | ||||||
| # forward and the per-chunk embeddings are concatenated. ``None`` | ||||||
| # disables chunking and runs a single forward over the full batch. | ||||||
| with torch.no_grad(): | ||||||
| audio_embeds, audio_embed_lens = self.perception( | ||||||
| input_signal=audio_signal, | ||||||
| input_signal_length=audio_lengths, | ||||||
| audio_embeds = encode_audio_with_optional_chunking( | ||||||
| self.perception, | ||||||
| audio_signal, | ||||||
| audio_lengths, | ||||||
| chunk_size_seconds=self.encoder_chunk_size_seconds, | ||||||
| sampling_rate=_SAMPLING_RATE, | ||||||
| ) | ||||||
|
|
||||||
| return tuple(audio_embeds[i, : audio_embed_lens[i]] for i in range(audio_embeds.shape[0])) | ||||||
| return tuple(emb.to(torch.bfloat16) for emb in audio_embeds) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit:
Suggested change
|
||||||
|
|
||||||
| def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: | ||||||
| audio_input = self._parse_audio_input(**kwargs) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,3 +80,53 @@ def test_estimator_matches_calc_length(samples: int) -> None: | |
def test_estimator_min_one() -> None:
    """A one-sample input must still map to at least one audio token."""
    n_tokens = NeMoSpeechLMProcessingInfo._estimate_audio_tokens(1)
    assert n_tokens >= 1
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we also need a test for |
||
|
|
||
def test_estimator_chunking_disabled_matches_single_pass() -> None:
    """With ``chunk_size_seconds=None`` the estimator equals the single-pass one."""
    info = NeMoSpeechLMProcessingInfo
    n_samples = 30 * 16_000
    chunking_disabled = info._estimate_audio_tokens(n_samples, chunk_size_seconds=None)
    single_pass = info._estimate_audio_tokens_single_pass(n_samples)
    assert chunking_disabled == single_pass
|
|
||
|
|
||
def test_estimator_short_audio_falls_back_to_single_pass() -> None:
    """Audio that fits inside one chunk is estimated as a single forward."""
    info = NeMoSpeechLMProcessingInfo
    n_samples = 5 * 16_000
    chunked = info._estimate_audio_tokens(n_samples, chunk_size_seconds=30.0)
    single_pass = info._estimate_audio_tokens_single_pass(n_samples)
    assert chunked == single_pass
|
|
||
|
|
||
def test_estimator_chunked_sums_per_chunk_frames() -> None:
    """Long audio is estimated chunk by chunk and the frame counts are added up,
    matching ``encode_audio_with_optional_chunking``'s concat behavior."""
    info = NeMoSpeechLMProcessingInfo
    total_samples = 90 * 16_000
    chunk_seconds = 30.0
    samples_per_chunk = int(round(chunk_seconds * 16_000))

    # Reference value: walk the audio in chunk-sized steps and accumulate the
    # single-pass estimate of each (possibly shorter, final) span.
    expected_frames = 0
    for start in range(0, total_samples, samples_per_chunk):
        span = min(samples_per_chunk, total_samples - start)
        expected_frames += info._estimate_audio_tokens_single_pass(span)

    actual_frames = info._estimate_audio_tokens(total_samples, chunk_size_seconds=chunk_seconds)
    assert actual_frames == expected_frames
|
|
||
|
|
||
def test_estimator_chunked_tail_folded_into_previous_chunk() -> None:
    """A tail shorter than the minimum chunk size is merged into the previous
    chunk, so the total matches the runtime helper rather than counting a
    spurious single-frame chunk that the audio preprocessor would reject."""
    info = NeMoSpeechLMProcessingInfo
    chunk_seconds = 30.0
    samples_per_chunk = int(round(chunk_seconds * 16_000))
    # One full chunk plus a 100-sample tail, below the 320-sample minimum.
    total_samples = samples_per_chunk + 100

    # After folding there is a single chunk covering the whole audio.
    expected_frames = info._estimate_audio_tokens_single_pass(total_samples)
    actual_frames = info._estimate_audio_tokens(total_samples, chunk_size_seconds=chunk_seconds)
    assert actual_frames == expected_frames
|
|
||
|
|
||
def test_estimator_negative_chunk_size_raises() -> None:
    """A negative chunk size is rejected with a ``ValueError`` naming the field."""
    one_second_of_samples = 16_000
    with pytest.raises(ValueError, match="encoder_chunk_size_seconds"):
        NeMoSpeechLMProcessingInfo._estimate_audio_tokens(one_second_of_samples, chunk_size_seconds=-1.0)
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think `len(spans) > 1` is redundant here. But we can keep it for now.