Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 98 additions & 5 deletions nemo/collections/speechlm2/vllm/salm/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@
_SAMPLING_RATE = 16000
_AUDIO_CHANNELS = 1
_DUMMY_AUDIO_DURATION_S = 40.0
_DUMMY_AUDIO_MAX_DURATION_S = 3600.0
_DUMMY_AUDIO_TEXT_TOKEN_RESERVE = 64
# Minimum encoder chunk length in samples. Derived from the FastConformer
# preprocessor hop length as the smallest chunk that produces ≥ 2 feature
# frames (per-feature normalization breaks on a single frame). Mirrors
# ``encoder_chunking._get_min_chunk_size_samples`` for the canonical
# preprocessor we ship: that helper probes the live featurizer at training
# time, but the prompt processor here runs before the perception module is
# loaded, so we hard-code the same value the helper would derive.
_MIN_CHUNK_SIZE_SAMPLES = 320


# ── Helpers ─────────────────────────────────────────────────────────
Expand Down Expand Up @@ -123,11 +134,20 @@ def get_data_parser(self) -> MultiModalDataParser:
)

def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1}
return {"audio": None}

def _get_encoder_chunk_size_seconds(self) -> float | None:
"""Return the per-encoder-call chunk size baked into the checkpoint.

Mirrors the training-time ``model.encoder_chunk_size_seconds`` field
(see ``encode_audio_with_optional_chunking``). ``None`` means the
encoder runs once over the full audio, matching legacy checkpoints.
"""
return getattr(self.get_hf_config(), "encoder_chunk_size_seconds", None)

@staticmethod
def _estimate_audio_tokens(audio_length_samples: int) -> int:
"""Predict the encoder's output frame count for an audio of N samples.
def _estimate_audio_tokens_single_pass(audio_length_samples: int) -> int:
"""Predict the encoder's output frame count for one perception forward.

Mirrors the FastConformer preprocessing chain used by
``AudioPerceptionModule``: STFT (n_fft=512, hop_length=160) followed
Expand All @@ -151,6 +171,64 @@ def _estimate_audio_tokens(audio_length_samples: int) -> int:
length = (length + add_pad) / stride + 1.0
return max(1, int(length))

@classmethod
def _estimate_audio_tokens(
    cls,
    audio_length_samples: int,
    chunk_size_seconds: float | None = None,
) -> int:
    """Estimate the total number of encoder output frames for an N-sample audio.

    The single-pass estimate is returned when chunking is disabled
    (``chunk_size_seconds is None``), when ``audio_length_samples`` is not
    positive, or when the audio fits inside one chunk. Otherwise the audio is
    cut into consecutive chunks the same way
    ``encode_audio_with_optional_chunking`` does (including its tail-folding
    rule) and the per-chunk estimates are added together, so the placeholder
    count matches what the model emits at forward time.

    Args:
        audio_length_samples: Length of the audio in samples.
        chunk_size_seconds: Encoder chunk size in seconds, or ``None`` to
            disable chunking.

    Returns:
        The estimated number of audio tokens (encoder frames).

    Raises:
        ValueError: If ``chunk_size_seconds`` is set but not positive. The
            check only runs for a positive ``audio_length_samples``.
    """
    chunking_disabled = chunk_size_seconds is None
    if chunking_disabled or audio_length_samples <= 0:
        return cls._estimate_audio_tokens_single_pass(audio_length_samples)
    if chunk_size_seconds <= 0.0:
        raise ValueError("encoder_chunk_size_seconds must be positive when set.")

    # Chunk length in samples, never below one sample nor below the minimum
    # chunk size.
    samples_per_chunk = max(
        1,
        int(round(chunk_size_seconds * _SAMPLING_RATE)),
        _MIN_CHUNK_SIZE_SAMPLES,
    )
    if audio_length_samples <= samples_per_chunk:
        return cls._estimate_audio_tokens_single_pass(audio_length_samples)

    # Half-open [start, stop) sample ranges, one per encoder call.
    chunk_bounds: list[tuple[int, int]] = [
        (start, min(start + samples_per_chunk, audio_length_samples))
        for start in range(0, audio_length_samples, samples_per_chunk)
    ]

    # Fold a too-short tail into the preceding chunk instead of encoding it
    # on its own.
    # NOTE(review): reviewer DongjiGao pointed out that the ``len(...) > 1``
    # guard looks redundant here; it is kept for now.
    tail_start, tail_stop = chunk_bounds[-1]
    if len(chunk_bounds) > 1 and tail_stop - tail_start < _MIN_CHUNK_SIZE_SAMPLES:
        chunk_bounds.pop()
        previous_start, _ = chunk_bounds[-1]
        chunk_bounds[-1] = (previous_start, tail_stop)

    total_tokens = 0
    for start, stop in chunk_bounds:
        total_tokens += cls._estimate_audio_tokens_single_pass(stop - start)
    return total_tokens

@classmethod
def _samples_for_audio_tokens(cls, target_tokens: int, chunk_size_seconds: float | None = None) -> int:
    """Find the smallest sample count estimated to yield ``target_tokens``.

    vLLM sizes the multimodal encoder cache from dummy inputs. The SALM
    plugin handles arbitrarily long audio by chunking the encoder forward,
    yet the decoder still receives the concatenated full-audio embedding
    sequence. Inverting the token estimator lets ``--limit-mm-per-prompt``
    audio length hints reserve cache for that whole sequence without
    hard-coding a single maximum call duration.

    The search never looks beyond ``_DUMMY_AUDIO_MAX_DURATION_S`` seconds of
    audio, so the result is capped at that many samples.

    NOTE(review): as raised in review, the cap can be returned even when
    that much audio is still estimated to produce fewer than
    ``target_tokens`` (e.g. when ``max_model_len`` is very large).

    Args:
        target_tokens: Desired number of audio tokens; values below 1 are
            treated as 1.
        chunk_size_seconds: Encoder chunk size forwarded to
            ``_estimate_audio_tokens``; ``None`` disables chunking.

    Returns:
        A sample count between 1 and the maximum dummy-audio length.
    """
    wanted_tokens = max(1, int(target_tokens))
    sample_cap = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE)

    def reaches_target(num_samples: int) -> bool:
        return cls._estimate_audio_tokens(num_samples, chunk_size_seconds) >= wanted_tokens

    # Grow the upper bracket by doubling until it is long enough or the cap
    # is hit.
    low = 1
    high = min(_SAMPLING_RATE, sample_cap)
    while high < sample_cap and not reaches_target(high):
        high = min(high * 2, sample_cap)

    # Bisect down to the first length that reaches the target.
    # Presumably the estimate is non-decreasing in the sample count; the
    # bisection relies on that - TODO confirm.
    while low < high:
        middle = (low + high) // 2
        if reaches_target(middle):
            high = middle
        else:
            low = middle + 1
    return low


class NeMoSpeechLMMultiModalProcessor(
BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo],
Expand Down Expand Up @@ -182,10 +260,11 @@ def _get_prompt_updates(
out_mm_kwargs: MultiModalKwargsItems,
) -> list[PromptUpdate]:
audios = mm_items.get_items("audio", AudioProcessorItems)
chunk_size_seconds = self.info._get_encoder_chunk_size_seconds()

def get_replacement(item_idx: int):
audio = audios.get(item_idx)
n_tokens = self.info._estimate_audio_tokens(audio.shape[-1])
n_tokens = self.info._estimate_audio_tokens(audio.shape[-1], chunk_size_seconds)
repl_full = _AUDIO_PLACEHOLDER * n_tokens
return PromptUpdateDetails.select_text(repl_full, _AUDIO_PLACEHOLDER)

Expand All @@ -210,6 +289,7 @@ def _call_hf_processor(
audios = mm_data.pop("audios", [])

if audios:
chunk_size_seconds = self.info._get_encoder_chunk_size_seconds()
audio_list: list[torch.Tensor] = []
audio_lengths: list[int] = []
parts = re.split(f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt)
Expand All @@ -229,7 +309,7 @@ def _call_hf_processor(
)
if audio_tensor.dim() > 1:
audio_tensor = audio_tensor.squeeze()
n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1])
n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1], chunk_size_seconds)
parts[i] = _AUDIO_PLACEHOLDER * n_tokens
audio_list.append(audio_tensor)
audio_lengths.append(audio_tensor.shape[-1])
Expand Down Expand Up @@ -257,6 +337,19 @@ def get_dummy_mm_data(
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
dummy_audio_len = int(_DUMMY_AUDIO_DURATION_S * _SAMPLING_RATE)
audio_options = mm_options.get("audio") if mm_options else None
requested_audio_len = getattr(audio_options, "length", None)
if requested_audio_len:
chunk_size_seconds = self.info._get_encoder_chunk_size_seconds()
if seq_len > _DUMMY_AUDIO_TEXT_TOKEN_RESERVE:
max_audio_tokens = seq_len - _DUMMY_AUDIO_TEXT_TOKEN_RESERVE
max_audio_len = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens(
max_audio_tokens,
chunk_size_seconds,
)
else:
max_audio_len = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE)
dummy_audio_len = min(int(requested_audio_len), max_audio_len)
return {
"audio": self._get_dummy_audios(
length=dummy_audio_len,
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/speechlm2/vllm/salm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
prompt_format: str | None = None,
pretrained_weights: bool | None = None,
lora: dict | None = None,
encoder_chunk_size_seconds: float | None = None,
**kwargs,
):
required_fields = {
Expand All @@ -88,6 +89,7 @@ def __init__(
is_default_init = (
perception is None
and lora is None
and encoder_chunk_size_seconds is None
and not kwargs
and all(value is None for value in required_fields.values())
)
Expand All @@ -112,6 +114,7 @@ def __init__(
self.prompt_format = None
self.pretrained_weights = None
self.lora = None
self.encoder_chunk_size_seconds = None
return

for name, value in required_fields.items():
Expand All @@ -137,6 +140,7 @@ def __init__(
self.prompt_format = prompt_format
self.pretrained_weights = pretrained_weights
self.lora = lora
self.encoder_chunk_size_seconds = encoder_chunk_size_seconds

self.text_config = AutoConfig.from_pretrained(pretrained_llm, trust_remote_code=True)

Expand Down Expand Up @@ -214,6 +218,7 @@ def __getattr__(self, name):
"text_config",
"lora",
"is_hybrid",
"encoder_chunk_size_seconds",
):
raise AttributeError(name)
alias = self._ATTR_ALIASES.get(name, name) if self.is_hybrid else name
Expand Down
19 changes: 15 additions & 4 deletions nemo/collections/speechlm2/vllm/salm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors

from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking
from nemo.collections.speechlm2.vllm.salm.audio import (
_SAMPLING_RATE,
NeMoSpeechLMAudioInputs,
NeMoSpeechLMDummyInputsBuilder,
NeMoSpeechLMMultiModalProcessor,
Expand Down Expand Up @@ -87,6 +89,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
self.encoder_chunk_size_seconds = getattr(config, "encoder_chunk_size_seconds", None)

backend = make_backend(config)
self._backend = backend
Expand Down Expand Up @@ -142,13 +145,21 @@ def _process_audio(self, audio_input: NeMoSpeechLMAudioInputs) -> tuple[torch.Te
audio_signal = audio_signal.to(device=device, dtype=_AUDIO_INPUT_DTYPE)
audio_lengths = audio_input.audio_signal_length.to(device=device)

# Mirrors training (``encode_audio_with_optional_chunking``): when the
# checkpoint was trained with a chunked encoder (e.g. SALMAutomodel
# default 30 s), long audios are split into chunks before the perception
# forward and the per-chunk embeddings are concatenated. ``None``
# disables chunking and runs a single forward over the full batch.
with torch.no_grad():
audio_embeds, audio_embed_lens = self.perception(
input_signal=audio_signal,
input_signal_length=audio_lengths,
audio_embeds = encode_audio_with_optional_chunking(
self.perception,
audio_signal,
audio_lengths,
chunk_size_seconds=self.encoder_chunk_size_seconds,
sampling_rate=_SAMPLING_RATE,
)

return tuple(audio_embeds[i, : audio_embed_lens[i]] for i in range(audio_embeds.shape[0]))
return tuple(emb.to(torch.bfloat16) for emb in audio_embeds)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
return tuple(emb.to(torch.bfloat16) for emb in audio_embeds)
return tuple(emb.to(_PERCEPTION_DTYPE) for emb in audio_embeds)


def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings:
audio_input = self._parse_audio_input(**kwargs)
Expand Down
50 changes: 50 additions & 0 deletions tests/collections/speechlm2/test_vllm_audio_token_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,53 @@ def test_estimator_matches_calc_length(samples: int) -> None:
def test_estimator_min_one() -> None:
"""Even for very short audio the estimator must return at least 1."""
assert NeMoSpeechLMProcessingInfo._estimate_audio_tokens(1) >= 1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also need tests for `_samples_for_audio_tokens()`, for audio sizing when `audio.length` is provided in `mm_options`, and for `_MIN_CHUNK_SIZE_SAMPLES = 320`.


def test_estimator_chunking_disabled_matches_single_pass() -> None:
    """With ``chunk_size_seconds=None`` the estimator equals the legacy single-pass estimate."""
    num_samples = 30 * 16_000
    chunk_aware = NeMoSpeechLMProcessingInfo._estimate_audio_tokens(num_samples, chunk_size_seconds=None)
    single_pass = NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(num_samples)
    assert chunk_aware == single_pass


def test_estimator_short_audio_falls_back_to_single_pass() -> None:
    """Audio shorter than one chunk is estimated as a single encoder forward."""
    num_samples = 5 * 16_000
    chunk_aware = NeMoSpeechLMProcessingInfo._estimate_audio_tokens(num_samples, chunk_size_seconds=30.0)
    single_pass = NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(num_samples)
    assert chunk_aware == single_pass


def test_estimator_chunked_sums_per_chunk_frames() -> None:
    """Long audio is split into chunks whose frame counts are summed.

    This matches the concat behavior of
    ``encode_audio_with_optional_chunking``.
    """
    chunk_size_seconds = 30.0
    total_samples = 90 * 16_000
    samples_per_chunk = int(round(chunk_size_seconds * 16_000))

    # Reference value: add up the single-pass estimate of every chunk.
    expected = 0
    for offset in range(0, total_samples, samples_per_chunk):
        chunk_length = min(samples_per_chunk, total_samples - offset)
        expected += NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(chunk_length)

    actual = NeMoSpeechLMProcessingInfo._estimate_audio_tokens(
        total_samples,
        chunk_size_seconds=chunk_size_seconds,
    )
    assert actual == expected


def test_estimator_chunked_tail_folded_into_previous_chunk() -> None:
    """A tail shorter than the minimum chunk size is merged into the previous chunk.

    The total token count then matches the runtime helper instead of
    producing a spurious single-frame chunk that the audio preprocessor
    would reject.
    """
    chunk_size_seconds = 30.0
    tail_samples = 100  # shorter than min_chunk_size_samples (320)
    total_samples = int(round(chunk_size_seconds * 16_000)) + tail_samples

    # After folding there is a single chunk spanning the whole audio.
    expected = NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(total_samples)
    actual = NeMoSpeechLMProcessingInfo._estimate_audio_tokens(
        total_samples,
        chunk_size_seconds=chunk_size_seconds,
    )
    assert actual == expected


def test_estimator_negative_chunk_size_raises() -> None:
    """A negative chunk size is rejected with a ``ValueError`` naming the field."""
    one_second_of_samples = 16_000
    with pytest.raises(ValueError, match="encoder_chunk_size_seconds"):
        NeMoSpeechLMProcessingInfo._estimate_audio_tokens(one_second_of_samples, chunk_size_seconds=-1.0)
26 changes: 24 additions & 2 deletions tests/collections/speechlm2/test_vllm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,26 @@ def test_unknown_attr_raises(self):
with pytest.raises(AttributeError):
_ = cfg.nonexistent_attribute_xyz

def test_encoder_chunk_size_seconds_default_none(self):
    """A config built without a chunk size keeps the single-pass encoder path (legacy checkpoints)."""
    config = NeMoSpeechLMConfig(**_DEFAULT_CONFIG_KWARGS)
    chunk_size = config.encoder_chunk_size_seconds
    assert chunk_size is None

def test_encoder_chunk_size_seconds_round_trips(self):
    """A chunk size given in config.json (e.g. SALMAutomodel default 30 s) is preserved on load."""
    config_kwargs = {**_DEFAULT_CONFIG_KWARGS, "encoder_chunk_size_seconds": 30.0}
    config = NeMoSpeechLMConfig(**config_kwargs)
    assert config.encoder_chunk_size_seconds == 30.0

def test_encoder_chunk_size_seconds_default_init_inert(self):
    """A no-argument default init still exposes ``encoder_chunk_size_seconds`` as ``None``."""
    config = NeMoSpeechLMConfig()
    chunk_size = config.encoder_chunk_size_seconds
    assert chunk_size is None


@pytest.mark.skipif(not (_HAS_CONFIG and _HAS_VLLM), reason="NeMoSpeechLMConfig or vLLM not available")
class TestBackendSelection:
Expand Down Expand Up @@ -332,7 +352,8 @@ def test_call_hf_processor_requires_matching_placeholder_count(self):
processor = object.__new__(NeMoSpeechLMMultiModalProcessor)
processor.info = SimpleNamespace(
get_tokenizer=_FakeTokenizer,
_estimate_audio_tokens=lambda samples: 2,
_estimate_audio_tokens=lambda samples, chunk_size_seconds=None: 2,
_get_encoder_chunk_size_seconds=lambda: None,
)

with pytest.raises(ValueError, match="placeholders"):
Expand All @@ -351,7 +372,8 @@ def test_call_hf_processor_emits_true_audio_lengths(self):
processor = object.__new__(NeMoSpeechLMMultiModalProcessor)
processor.info = SimpleNamespace(
get_tokenizer=_FakeTokenizer,
_estimate_audio_tokens=lambda samples: 2,
_estimate_audio_tokens=lambda samples, chunk_size_seconds=None: 2,
_get_encoder_chunk_size_seconds=lambda: None,
)

result = processor._call_hf_processor(
Expand Down
Loading