Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def torch_impl(
)
scores, labels = logits[:, :-num_durations].max(dim=-1)

logits_with_fusion = None
if self.has_fusion_models():
fusion_scores_list, fusion_states_candidates_list = [], []
logits_with_fusion = logits.clone()
Expand All @@ -485,6 +486,22 @@ def torch_impl(

# get max scores and labels without blank
fusion_scores_max, fusion_labels_max = logits_with_fusion[:, : -num_durations - 1].max(dim=-1)

if self.preserve_alignments:
# The following code applies the fusion logic to the logits_with_fusion tensor.
# The goal is to ensure that the labels are coherent with the logits.

blank_is_best_without_fusion = labels == self._blank_index

# If blank is best without fusion, use original logits for that sample
if blank_is_best_without_fusion.any():
logits_with_fusion[blank_is_best_without_fusion] = logits[blank_is_best_without_fusion]

# If blank is NOT best without fusion, use fused logits but set blank to -inf to ensure that the blank is not selected.
non_blank_is_best = ~blank_is_best_without_fusion
if non_blank_is_best.any():
logits_with_fusion[non_blank_is_best, self._blank_index] = float('-inf')

# preserve "blank" / "non-blank" category
torch.where(labels == self._blank_index, labels, fusion_labels_max, out=labels)
torch.where(labels == self._blank_index, scores, fusion_scores_max, out=scores)
Expand All @@ -499,10 +516,11 @@ def torch_impl(
durations.masked_fill_(torch.logical_and(durations == 0, blank_mask), 1)
time_indices_current_labels.copy_(time_indices)
if use_alignments:
logits_alignments = logits_with_fusion if logits_with_fusion is not None else logits
alignments.add_results_masked_(
active_mask=active_mask,
time_indices=time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
logits=logits_alignments if self.preserve_alignments else None,
labels=labels if self.preserve_alignments else None,
confidence=self._get_frame_confidence(logits=logits, num_durations=num_durations),
)
Expand Down Expand Up @@ -531,6 +549,7 @@ def torch_impl(
# labels[advance_mask] are blank, and we are looking for non-blank labels
more_scores, more_labels = logits[:, :-num_durations].max(dim=-1)

logits_with_fusion = None
if self.has_fusion_models():
logits_with_fusion = logits.clone()
for fusion_scores in fusion_scores_list:
Expand All @@ -540,6 +559,22 @@ def torch_impl(
more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, : -num_durations - 1].max(
dim=-1
)

if self.preserve_alignments:
# The following code applies the fusion logic to the logits_with_fusion tensor.
# The goal is to ensure that the labels are coherent with the logits.

blank_is_best_without_fusion = more_labels == self._blank_index

# If blank is best without fusion, use original logits for that sample
if blank_is_best_without_fusion.any():
logits_with_fusion[blank_is_best_without_fusion] = logits[blank_is_best_without_fusion]

# If blank is NOT best without fusion, use fused logits but set blank to -inf
non_blank_is_best = ~blank_is_best_without_fusion
if non_blank_is_best.any():
logits_with_fusion[non_blank_is_best, self._blank_index] = float('-inf')

# preserve "blank" / "non-blank" category
torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels)

Expand All @@ -551,10 +586,11 @@ def torch_impl(
durations = model_durations[jump_durations_indices]

if use_alignments:
logits_alignments = logits_with_fusion if logits_with_fusion is not None else logits
alignments.add_results_masked_(
active_mask=advance_mask,
time_indices=time_indices_current_labels,
logits=logits if self.preserve_alignments else None,
logits=logits_alignments if self.preserve_alignments else None,
labels=more_labels if self.preserve_alignments else None,
confidence=self._get_frame_confidence(logits=logits, num_durations=num_durations),
)
Expand Down
83 changes: 82 additions & 1 deletion tests/collections/asr/decoding/test_rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytest
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint
Expand Down Expand Up @@ -520,6 +520,87 @@ def test_tdt_greedy_decoding(
enable_per_stream_biasing=enable_per_stream_biasing,
)

@pytest.mark.skipif(
    not NUMBA_RNNT_LOSS_AVAILABLE,
    reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.with_downloads
@pytest.mark.unit
@pytest.mark.parametrize("use_lm", [True, False])
@pytest.mark.parametrize("use_boosting_tree", [True, False])
def test_tdt_greedy_decoding_preserve_alignments_with_fusion(
    self,
    test_data_dir,
    use_lm: bool,
    use_boosting_tree: bool,
):
    """
    Check alignment coherence for greedy TDT decoding with fusion models enabled.

    When an n-gram LM and/or a boosting tree is fused into greedy TDT decoding with
    ``preserve_alignments=True``, every stored alignment label must equal the argmax
    over the vocabulary portion of the stored logits (the trailing duration logits
    are excluded). Runs only when at least one fusion model is active.
    """
    if not use_lm and not use_boosting_tree:
        pytest.skip("At least one fusion model must be enabled for this test")

    model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'nvidia/parakeet-tdt_ctc-110m')
    model_config = model.to_config_dict()

    assert model.decoding._is_tdt, "Model is not a TDT model"

    kenlm_model_path = Path(test_data_dir) / "asr/kenlm_ngram_lm/parakeet-tdt_ctc-110m-libri-1024.kenlm.tmp.arpa"

    # Assemble the requested fusion models and a matching per-model weight list.
    external_scorers = []
    scorer_weights = []
    if use_lm:
        ngram_lm = NGramGPULanguageModel.from_file(lm_path=kenlm_model_path, vocab_size=model.decoder.blank_idx)
        external_scorers.append(ngram_lm)
        scorer_weights.append(10)
    if use_boosting_tree:
        tree_cfg = BoostingTreeModelConfig(key_phrases_list=["hello", "nvidia"])
        external_scorers.append(GPUBoostingTreeModel.from_config(tree_cfg, tokenizer=model.tokenizer))
        scorer_weights.append(10)

    # Greedy TDT decoder with fusion models attached and alignment preservation on.
    decoder_under_test = greedy_decode.GreedyBatchedTDTInfer(
        model.decoder,
        model.joint,
        blank_index=model.decoder.blank_idx,
        durations=list(model_config["model_defaults"]["tdt_durations"]),
        max_symbols_per_step=10,
        preserve_alignments=True,
        preserve_frame_confidence=False,
        use_cuda_graph_decoder=False,  # Use torch impl for this test
        fusion_models=external_scorers,
        fusion_models_alpha=scorer_weights,
    )

    # Number of duration logits appended after the vocabulary logits.
    num_durations = len(OmegaConf.to_container(model.decoding.durations, resolve=True))

    with torch.no_grad():
        raw_hyps = decoder_under_test(encoder_output=encoded, encoded_lengths=encoded_len)[0]
        hyp_with_fusion = decode_text_from_greedy_hypotheses(raw_hyps, model.decoding)[0]

    # Alignments must have been recorded.
    assert hyp_with_fusion.alignments is not None, "Alignments should be preserved with fusion models"

    for t, timestep_alignments in enumerate(hyp_with_fusion.alignments):
        for u, (logp, label) in enumerate(timestep_alignments):
            # Each alignment entry is a (logits, label) tensor pair.
            assert torch.is_tensor(logp), f"Logits at t={t}, u={u} should be a tensor"
            assert torch.is_tensor(label), f"Label at t={t}, u={u} should be a tensor"

            # Core invariant: the stored label matches the argmax over vocab logits
            # (duration logits at the tail are excluded from the comparison).
            vocab_logits = logp[:-num_durations]
            assert vocab_logits.argmax() == int(label), (
                f"Label should match argmax of vocab logits. "
                f"Got argmax={vocab_logits.argmax()}, label={int(label)} at t={t}, u={u}"
            )

@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE,
reason='RNNTLoss has not been compiled with appropriate numba version.',
Expand Down
Loading