17 changes: 3 additions & 14 deletions tensorrt_llm/_torch/models/modeling_nemotron_nano.py
@@ -1,4 +1,4 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
import copy
import math
import os
@@ -2169,19 +2169,8 @@ def _prepare_audio_features(

expanded_text = self._expand_audio_placeholders(text, audios, extractor)

audio_inputs = extractor(
audios,
sampling_rate=extractor.sampling_rate,
return_tensors="pt",
)
audio_data = {
"input_audio_features": audio_inputs.input_features,
"feature_attention_mask": audio_inputs.attention_mask,
}
# audio_num_clips records how many clips each audio stream was split
# into. Needed to regroup per-clip embeddings back to per-video.
audio_data["audio_num_clips"] = audio_inputs.audio_num_clips
return expanded_text, audio_data
audio_inputs = extractor(audios)
return expanded_text, audio_inputs
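# The returned `audio_inputs` dict now carries everything downstream code needs:
#   "input_audio_features"   - (total_clips, frames, mel) log-mel features
#   "feature_attention_mask" - (total_clips, frames) valid-frame mask
#   "audio_num_clips"        - (num_audios,) clips per stream, used to regroup embeddings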

def _process_audio(
self,
237 changes: 205 additions & 32 deletions tensorrt_llm/_torch/models/modeling_parakeet.py
@@ -13,30 +13,138 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import cache
from typing import Dict, NamedTuple, Optional

import numpy as np
import torch
import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder
from transformers import ParakeetEncoderConfig, ParakeetFeatureExtractor, PretrainedConfig
from transformers import ParakeetEncoderConfig, PretrainedConfig
from transformers.audio_utils import mel_filter_bank

from ...logger import logger
from ..modules.rms_norm import RMSNorm

EPSILON = 1e-5
LOG_ZERO_GUARD_VALUE = 2**-24

class ParakeetExtractor(ParakeetFeatureExtractor):

class ParakeetExtractor:
def __init__(self, config: PretrainedConfig) -> None:
self.config = _ExtractorConfig(
feature_size=config.num_mel_bins,
sampling_rate=config.sampling_rate,
subsampling_factor=config.subsampling_factor,
subsampling_conv_kernel_size=config.subsampling_conv_kernel_size,
subsampling_conv_stride=config.subsampling_conv_stride,
# Keep `self.config` as the single source of truth because
# `HFParakeetEncoder._get_subsampling_output_length` reads subsampling fields from it.
self.config = _ExtractorConfig.from_hf_config(config)

self._clip_target_samples = round(self.config.clip_duration_s * self.sampling_rate)
self._tail_min_samples = round(self.config.clip_min_duration_s * self.sampling_rate)

@property
def sampling_rate(self) -> int:
return self.config.sampling_rate

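# The window and mel filter bank depend only on hashable scalars plus the device string,
# so functools.cache memoizes one instance per (params, device) combination.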
@staticmethod
@cache
def _get_window(win_length: int, device: str) -> torch.Tensor:
return torch.hann_window(win_length, periodic=False, device=device)

@staticmethod
@cache
def _get_mel_filters(
feature_size: int, sampling_rate: int, n_fft: int, device: str
) -> torch.Tensor:
filter_bank = mel_filter_bank(
num_frequency_bins=n_fft // 2 + 1,
num_mel_filters=feature_size,
min_frequency=0.0,
max_frequency=sampling_rate / 2,
sampling_rate=sampling_rate,
norm="slaney",
mel_scale="slaney",
)
super().__init__(**self.config._asdict())
return torch.from_numpy(filter_bank.T).to(device=device, dtype=torch.float32)

self._clip_target_samples = int(round(self.config.clip_duration_s * self.sampling_rate))
self._tail_min_samples = int(round(self.config.clip_min_duration_s * self.sampling_rate))
def _torch_extract_fbank_features(
self, waveform: torch.Tensor, device: str | torch.device
) -> torch.Tensor:
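# Batched STFT with center padding (pad_mode="constant"); for even n_fft this yields
# 1 + samples // hop_length frames per clip.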
device = str(torch.device(device))
cfg = self.config
window = self._get_window(cfg.win_length, device)
stft = torch.stft(
waveform,
cfg.n_fft,
hop_length=cfg.hop_length,
win_length=cfg.win_length,
window=window,
return_complex=True,
pad_mode="constant",
)
mel_filters = self._get_mel_filters(cfg.feature_size, cfg.sampling_rate, cfg.n_fft, device)
return self._apply_mel_filters(stft, mel_filters)

@torch.compile(dynamic=True)
def _apply_mel_filters(
self, stft_output: torch.Tensor, mel_filters: torch.Tensor
) -> torch.Tensor:
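# Power spectrum of the complex STFT, projected through the mel filter bank; the additive
# LOG_ZERO_GUARD_VALUE keeps log() finite on silent or fully padded frames.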
magnitudes = stft_output.real.square() + stft_output.imag.square()
mel_spec = mel_filters @ magnitudes
mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)
return mel_spec.permute(0, 2, 1)

@torch.compile(dynamic=True)
def _apply_preemphasis(
self, input_features: torch.Tensor, audio_lengths: torch.Tensor
) -> torch.Tensor:
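# Classic pre-emphasis filter: y[0] = x[0], y[t] = x[t] - k * x[t-1] with k = config.preemphasis.
# Samples beyond each clip's true length are re-zeroed so the filter does not leak into padding.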
preemphasis = self.config.preemphasis
if preemphasis is None:
return input_features
timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze(
0
) < audio_lengths.unsqueeze(1)
input_features = torch.cat(
[
input_features[:, :1],
input_features[:, 1:] - preemphasis * input_features[:, :-1],
],
dim=1,
)
return input_features.masked_fill(~timemask, 0.0)

@torch.compile(dynamic=True)
def _normalize_mel_features(
self, mel_features: torch.Tensor, audio_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
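# Normalize each clip per mel bin using statistics over valid frames only (Bessel-corrected
# variance); padded frames are zeroed in the output so downstream masking stays consistent.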
features_lengths = torch.floor_divide(
audio_lengths + self.config.n_fft // 2 * 2 - self.config.n_fft,
self.config.hop_length,
)
attention_mask = (
torch.arange(mel_features.shape[1], device=mel_features.device)[None, :]
< features_lengths[:, None]
)
mask = attention_mask.unsqueeze(-1)
lengths = attention_mask.sum(dim=1)
mel_features_masked = mel_features * mask
mean = (mel_features_masked.sum(dim=1) / lengths.unsqueeze(-1)).unsqueeze(1)
variance = ((mel_features_masked - mean) ** 2 * mask).sum(dim=1) / (lengths - 1).unsqueeze(
-1
)
std = torch.sqrt(variance).unsqueeze(1)
return (mel_features - mean) / (std + EPSILON) * mask, attention_mask

def _pad_raw_speech(
self, raw_speech: list[torch.Tensor], max_len: int, device: str | torch.device
) -> torch.Tensor:
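# Right-pad every clip to max_len with config.padding_value; torch._foreach_copy_ performs
# all per-clip copies in a single foreach call.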
output = torch.full(
(len(raw_speech), max_len),
self.config.padding_value,
device=device,
dtype=torch.float32,
)
dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))]
srcs = [speech.squeeze(-1) for speech in raw_speech]
torch._foreach_copy_(dsts, srcs)
return output

def _clip_sizes(self, audio_len: int) -> list[int]:
audio_len = max(audio_len, self._tail_min_samples)
@@ -48,59 +156,124 @@ def _clip_sizes(self, audio_len: int) -> list[int]:

def audio_token_count(self, audio_len: int) -> int:
clip_sizes = self._clip_sizes(audio_len)
num_frames = torch.tensor([cs // self.hop_length for cs in clip_sizes], dtype=torch.float)
num_frames = torch.tensor(
[cs // self.config.hop_length for cs in clip_sizes], dtype=torch.float
)
# NOTE: intentional hack: we pass the extractor itself as `self` so the HF method can read
# the subsampling fields from `self.config` rather than duplicating the length math here.
n_tokens = HFParakeetEncoder._get_subsampling_output_length(self, num_frames)
return max(1, int(n_tokens.sum().item()))

def _split_audio_into_clips(self, audio: np.ndarray) -> list[np.ndarray]:
if audio.ndim == 2:
if audio.shape[1] == 0:
def _to_mono_tensor(
self, audio: np.ndarray | torch.Tensor, device: str | torch.device
) -> torch.Tensor:
audio_tensor = torch.as_tensor(audio, device=device, dtype=torch.float32)
if audio_tensor.ndim == 2:
if audio_tensor.shape[1] == 0:
raise ValueError(
f"Unsupported audio shape {audio.shape}: expected at least one channel"
f"Unsupported audio shape {audio_tensor.shape}: expected at least one channel"
)
audio = audio.mean(axis=1)
elif audio.ndim != 1:
logger.warning(
f"Only mono-channel audio is supported for input to {self.__class__.__name__} "
"We will take the mean of the channels to convert to mono."
)
audio_tensor = audio_tensor.mean(dim=-1)
elif audio_tensor.ndim != 1:
raise ValueError(
f"Unsupported audio shape {audio.shape}: "
f"Unsupported audio shape {audio_tensor.shape}: "
"expected 1-D (mono) or 2-D (samples x channels)"
)
return audio_tensor

def split_audio_into_clips(self, audio: torch.Tensor) -> list[torch.Tensor]:
if audio.ndim != 1:
raise ValueError(f"Unsupported audio shape {audio.shape}: expected 1-D mono audio")
audio_len = int(audio.shape[0])
clip_sizes = self._clip_sizes(audio_len)
target_len = sum(clip_sizes)
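# _clip_sizes may round the tail clip up to clip_min_duration_s, so zero-pad to the target.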
if audio_len < target_len:
audio = np.pad(audio, (0, target_len - audio_len))
audio = torch.nn.functional.pad(audio, (0, target_len - audio_len))

clips: list[torch.Tensor] = []
offset = 0
for clip_size in clip_sizes:
clips.append(audio[offset : offset + clip_size])
offset += clip_size
return clips

def _split_audio_into_clips(self, audio: np.ndarray | torch.Tensor) -> list[torch.Tensor]:
audio_tensor = self._to_mono_tensor(audio, "cpu")
return self.split_audio_into_clips(audio_tensor)

split_indices = np.cumsum(clip_sizes[:-1])
return np.split(audio, split_indices)
def __call__(
self,
raw_speech: list[np.ndarray | torch.Tensor],
*,
device: str | torch.device = "cpu",
) -> dict[str, torch.Tensor]:
if len(raw_speech) == 0:
raise ValueError("raw_speech must contain at least one audio array.")

def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs) -> torch.Tensor:
audio_clips = list[np.ndarray]()
audio_num_clips = list[int]()
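# Pipeline: mono conversion -> clip splitting -> batch padding -> optional pre-emphasis ->
# log-mel extraction -> per-clip normalization.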
raw_speech = [self._to_mono_tensor(audio, device) for audio in raw_speech]
audio_clips: list[torch.Tensor] = []
audio_num_clips: list[int] = []
for audio in raw_speech:
clips = self._split_audio_into_clips(audio)
clips = self.split_audio_into_clips(audio)
audio_clips.extend(clips)
audio_num_clips.append(len(clips))

outputs = super().__call__(audio_clips, *args, **kwargs)
outputs["audio_num_clips"] = torch.tensor(audio_num_clips, dtype=torch.long)
return outputs
audio_lengths = torch.tensor(
[len(speech) for speech in audio_clips],
dtype=torch.long,
device=device,
)
max_length = max(len(speech) for speech in audio_clips)
input_features = self._pad_raw_speech(audio_clips, max_length, device)
if self.config.preemphasis is not None:
input_features = self._apply_preemphasis(input_features, audio_lengths)
input_features = self._torch_extract_fbank_features(input_features, device)
input_features, attention_mask = self._normalize_mel_features(input_features, audio_lengths)

return {
"input_audio_features": input_features,
"feature_attention_mask": attention_mask,
"audio_num_clips": torch.tensor(audio_num_clips, dtype=torch.long, device=device),
}

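# Approximate inverse of audio_token_count: raw samples spanned by a given token count.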
def audio_length(self, audio_tokens: int) -> int:
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)
return int(audio_tokens * self.config.subsampling_factor * self.config.hop_length)


# The sole purpose of this config object is so that we are able to make the call to
# `HFParakeetEncoder._get_subsampling_output_length`, which just reads a few of the below values.
# This config is the extractor's single source of truth. It also supplies the fields read by
# `HFParakeetEncoder._get_subsampling_output_length`.
class _ExtractorConfig(NamedTuple):
feature_size: int
sampling_rate: int
subsampling_factor: int
subsampling_conv_kernel_size: int
subsampling_conv_stride: int
hop_length: int = 160
win_length: int = 400
preemphasis: float | None = 0.97
n_fft: int = 512
padding_value: float = 0.0
clip_duration_s: int = 30
clip_min_duration_s: float = 0.1

@classmethod
def from_hf_config(cls, config: PretrainedConfig) -> "_ExtractorConfig":
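# Fields absent from the HF config fall back to the NamedTuple defaults via _field_defaults.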
return cls(
feature_size=config.num_mel_bins,
sampling_rate=config.sampling_rate,
subsampling_factor=config.subsampling_factor,
subsampling_conv_kernel_size=config.subsampling_conv_kernel_size,
subsampling_conv_stride=config.subsampling_conv_stride,
hop_length=getattr(config, "hop_length", cls._field_defaults["hop_length"]),
win_length=getattr(config, "win_length", cls._field_defaults["win_length"]),
preemphasis=getattr(config, "preemphasis", cls._field_defaults["preemphasis"]),
n_fft=getattr(config, "n_fft", cls._field_defaults["n_fft"]),
padding_value=getattr(config, "padding_value", cls._field_defaults["padding_value"]),
)


def _make_parakeet_encoder_config(
sound_config: PretrainedConfig,
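For context, a minimal usage sketch of the rewritten extractor. The config stand-in below is hypothetical; only the field names that `_ExtractorConfig.from_hf_config` reads are assumed, and the remaining fields take the NamedTuple defaults.

import numpy as np

class _FakeSoundConfig:  # hypothetical stand-in; real configs come from the HF model
    num_mel_bins = 128
    sampling_rate = 16000
    subsampling_factor = 8
    subsampling_conv_kernel_size = 3
    subsampling_conv_stride = 2

extractor = ParakeetExtractor(_FakeSoundConfig())
audios = [
    np.random.randn(16000).astype(np.float32),     # 1 s mono
    np.random.randn(48000, 2).astype(np.float32),  # 3 s stereo, averaged to mono
]
out = extractor(audios, device="cpu")
# out["input_audio_features"]:  (total_clips, frames, num_mel_bins) log-mel features
# out["feature_attention_mask"]: (total_clips, frames) bool mask of valid frames
# out["audio_num_clips"]:        (num_audios,) clips per input audio

Both inputs here are shorter than the 30 s clip target, so each maps to a single clip and the batch holds two clips total.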
43 changes: 40 additions & 3 deletions tests/unittest/_torch/modeling/test_modeling_parakeet.py
@@ -99,13 +99,50 @@ def test_split_audio_into_clips_preserves_total_samples(self):
def test_call_returns_expected_keys(self):
ext = _make_extractor()
audios = [np.random.randn(16000).astype(np.float32)]
result = ext(audios, sampling_rate=16000, return_tensors="pt")
assert "input_features" in result
assert "attention_mask" in result
result = ext(audios)
assert "input_audio_features" in result
assert "feature_attention_mask" in result
assert "audio_num_clips" in result
assert result["input_audio_features"].ndim == 3
assert result["input_audio_features"].shape[0] == 1
assert result["input_audio_features"].shape[2] == ext.config.feature_size
assert result["feature_attention_mask"].shape == result["input_audio_features"].shape[:2]
assert result["audio_num_clips"].shape == (1,)
assert result["audio_num_clips"].item() >= 1

def test_call_accepts_stereo_audio(self):
ext = _make_extractor()
audio = np.random.randn(16000, 2).astype(np.float32)
result = ext([audio])
assert result["input_audio_features"].shape[0] == 1
assert result["audio_num_clips"].item() == 1

def test_config_overrides(self):
hop_length = 80
win_length = 320
preemphasis = 0.5
n_fft = 256
padding_value = -1.0

ext = _make_extractor(
hop_length=hop_length,
win_length=win_length,
preemphasis=preemphasis,
n_fft=n_fft,
padding_value=padding_value,
)
assert ext.config.hop_length == hop_length
assert ext.config.win_length == win_length
assert ext.config.preemphasis == preemphasis
assert ext.config.n_fft == n_fft
assert ext.config.padding_value == padding_value

def test_sampling_rate_property(self):
ext = _make_extractor(sampling_rate=8000)
assert ext.sampling_rate == 8000
with pytest.raises(AttributeError):
ext.sampling_rate = 16000


class TestProjectedParakeet:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")