Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e7382f9
Move fusion models to the base class
artbataev Jan 21, 2026
e45219c
Clean up
artbataev Jan 21, 2026
8348900
Unify implementation: torch and CUDA graphs
artbataev Jan 21, 2026
b2ef196
BatchedHyps: handle step confidence
artbataev Jan 21, 2026
3eb333a
Refactor
artbataev Jan 21, 2026
73968de
Refactor
artbataev Jan 26, 2026
89d6b40
RNN-T: use confidence without alignments data structure
artbataev Jan 26, 2026
3dbc746
Fix TDT inconsistency: store alignment logits with fusion models scor…
artbataev Jan 26, 2026
310d6ee
Merge branch 'main' into vbataev/rnnt_decoding_refactor
artbataev Jan 26, 2026
20030c5
Always store labels in alignments
artbataev Jan 26, 2026
0aa81f4
Fix confidence
artbataev Jan 26, 2026
9d4bd7a
Make confidence consistent
artbataev Jan 26, 2026
f397f14
Unify RNN-T and TDT torch_impl
artbataev Jan 26, 2026
f460a7e
Unify RNN-T and TDT
artbataev Jan 26, 2026
0e69ab5
Fix TDT decoding
artbataev Jan 26, 2026
d197dce
Fix tests
artbataev Jan 26, 2026
38fec91
Apply isort and black reformatting
artbataev Jan 26, 2026
35e57b0
Fix TDT confidence without blank
artbataev Jan 26, 2026
c368304
Merge remote-tracking branch 'origin/vbataev/rnnt_decoding_refactor' …
artbataev Jan 26, 2026
a2fa128
Merge branch 'main' into vbataev/rnnt_decoding_refactor
artbataev Jan 26, 2026
902ccd8
Merge branch 'main' into vbataev/rnnt_decoding_refactor
artbataev Jan 27, 2026
ed1dd08
Fix after merge
artbataev Jan 27, 2026
31a6764
Revert unnecessary test change
artbataev Jan 27, 2026
941bb6e
Bugfix
artbataev Jan 27, 2026
ea6cff4
Merge branch 'main' into vbataev/rnnt_decoding_refactor
artbataev Feb 3, 2026
ac3b78b
Merge branch 'main' into vbataev/rnnt_decoding_refactor
artbataev Apr 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 38 additions & 26 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
exclude_blank_from_confidence=self.exclude_blank_from_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True),
Expand All @@ -464,6 +465,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu
),
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
exclude_blank_from_confidence=self.exclude_blank_from_confidence,
include_duration=self.tdt_include_token_duration,
include_duration_confidence=self.tdt_include_duration_confidence,
confidence_method_cfg=self.confidence_method_cfg,
Expand Down Expand Up @@ -815,18 +817,40 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes
"""
if self._is_tdt:
# if self.tdt_include_duration_confidence is True then frame_confidence elements consist of two numbers
maybe_pre_aggregate = (
(lambda x: self._aggregate_confidence(x)) if self.tdt_include_duration_confidence else (lambda x: x)
)
for hyp in hypotheses_list:
token_confidence = []
# trying to recover frame_confidence according to alignments
subsequent_blank_confidence = []
# going backwards since <blank> tokens are considered belonging to the last non-blank token.
for fc, fa in zip(hyp.frame_confidence[::-1], hyp.alignments[::-1]):
# there is only one score per frame most of the time
if len(fa) > 1:
for i, a in reversed(list(enumerate(fa))):
if self.exclude_blank_from_confidence and all(
hyp.non_blank_step_confidence_precomputed is not None for hyp in hypotheses_list
):
for hyp in hypotheses_list:
hyp.token_confidence = hyp.non_blank_step_confidence_precomputed
else:
maybe_pre_aggregate = (
(lambda x: self._aggregate_confidence(x))
if self.tdt_include_duration_confidence
else (lambda x: x)
)
for hyp in hypotheses_list:
token_confidence = []
# trying to recover frame_confidence according to alignments
subsequent_blank_confidence = []
# going backwards since <blank> tokens are considered belonging to the last non-blank token.
for fc, fa in zip(hyp.frame_confidence[::-1], hyp.alignments[::-1]):
# there is only one score per frame most of the time
if len(fa) > 1:
for i, a in reversed(list(enumerate(fa))):
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
elif not subsequent_blank_confidence:
token_confidence.append(maybe_pre_aggregate(fc[i]))
else:
token_confidence.append(
self._aggregate_confidence(
[maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence
)
)
subsequent_blank_confidence = []
else:
i, a = 0, fa[0]
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
Expand All @@ -839,20 +863,8 @@ def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothes
)
)
subsequent_blank_confidence = []
else:
i, a = 0, fa[0]
if a[-1] == self.blank_id:
if not self.exclude_blank_from_confidence:
subsequent_blank_confidence.append(maybe_pre_aggregate(fc[i]))
elif not subsequent_blank_confidence:
token_confidence.append(maybe_pre_aggregate(fc[i]))
else:
token_confidence.append(
self._aggregate_confidence([maybe_pre_aggregate(fc[i])] + subsequent_blank_confidence)
)
subsequent_blank_confidence = []
token_confidence = token_confidence[::-1]
hyp.token_confidence = token_confidence
token_confidence = token_confidence[::-1]
hyp.token_confidence = token_confidence
else:
if self.exclude_blank_from_confidence:
for hyp in hypotheses_list:
Expand Down
10 changes: 8 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def __init__(
max_symbols_per_step: Optional[int] = None,
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
exclude_blank_from_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = True,
use_cuda_graph_decoder: bool = True,
Expand All @@ -629,6 +630,7 @@ def __init__(

self.use_cuda_graph_decoder = use_cuda_graph_decoder
self.loop_labels = loop_labels
self.exclude_blank_from_confidence = exclude_blank_from_confidence

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
Expand All @@ -643,7 +645,8 @@ def __init__(
blank_index=self._blank_index,
max_symbols_per_step=self.max_symbols,
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
preserve_step_confidence=preserve_frame_confidence,
exclude_blank_from_confidence=self.exclude_blank_from_confidence,
confidence_method_cfg=confidence_method_cfg,
allow_cuda_graphs=self.use_cuda_graph_decoder,
fusion_models=fusion_models,
Expand Down Expand Up @@ -2839,6 +2842,7 @@ def __init__(
max_symbols_per_step: Optional[int] = None,
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
exclude_blank_from_confidence: bool = False,
include_duration: bool = False,
include_duration_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
Expand All @@ -2859,6 +2863,7 @@ def __init__(
self.durations = durations
self.include_duration = include_duration
self.include_duration_confidence = include_duration_confidence
self.exclude_blank_from_confidence = exclude_blank_from_confidence

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
Expand All @@ -2873,7 +2878,8 @@ def __init__(
durations=self.durations,
max_symbols_per_step=self.max_symbols,
preserve_alignments=preserve_alignments,
preserve_frame_confidence=preserve_frame_confidence,
preserve_step_confidence=preserve_frame_confidence,
exclude_blank_from_confidence=self.exclude_blank_from_confidence,
include_duration=include_duration,
include_duration_confidence=include_duration_confidence,
confidence_method_cfg=confidence_method_cfg,
Expand Down
Loading
Loading