
[MagpieTTS] Load Whisper with torch_dtype="auto" for fp16 inference#15680

Open
matteolippi wants to merge 1 commit into NVIDIA-NeMo:main from matteolippi:whisper-fp16-loading

Conversation

@matteolippi
Contributor

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do?

Pass torch_dtype="auto" to WhisperForConditionalGeneration.from_pretrained
in the MagpieTTS preference-optimization models, so Whisper loads in its
native fp16 dtype instead of being implicitly upcast to fp32.

HF whisper-large-v3 ships fp16 weights, but from_pretrained without
torch_dtype defaults to an fp32 upcast, which doubles VRAM (~3 GB extra)
and prevents torch.nn.functional.scaled_dot_product_attention from
dispatching to the flash kernel (flash-SDPA is fp16/bf16 only; fp32 falls
back to the mem-efficient backend). torch_dtype="auto" tells HuggingFace
to load the model in the dtype declared by the checkpoint's config.json
(float16 for whisper-large-v3).
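
As a minimal sketch of the change (the openai/whisper-large-v3 checkpoint ID is assumed here for illustration; the actual checkpoint name used in magpietts_preference_optimization.py may differ):

```python
from transformers import WhisperForConditionalGeneration

# Previous behavior: no torch_dtype, so the fp16 checkpoint is upcast to
# fp32 on load (~6.3 GB of weights for whisper-large-v3).
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")

# This PR: torch_dtype="auto" loads the dtype declared in the checkpoint's
# config.json (float16 for whisper-large-v3), ~3.1 GB of weights.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v3",
    torch_dtype="auto",
)
```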

Verified on A100 with 100 CML-TTS Italian samples (batch_size=1, greedy):

| dtype | weights VRAM | peak VRAM | median latency | throughput |
| --- | --- | --- | --- | --- |
| fp32 (current default) | 6.31 GB | 6.85 GB | 857.8 ms | 1.12 samp/s |
| fp16 (this PR) | 3.09 GB | 3.37 GB | 789.5 ms | 1.21 samp/s |
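
For reference, a rough sketch of how peak-VRAM and throughput numbers like these can be reproduced (not the exact benchmarking script used here; the model setup and the CML-TTS feature extraction are assumed to happen elsewhere):

```python
import time
import torch

def benchmark(model, inputs_list):
    # inputs_list: preprocessed Whisper input features, one tensor per sample
    # (assumed to be prepared elsewhere, e.g. via WhisperProcessor).
    torch.cuda.reset_peak_memory_stats()
    latencies = []
    for inputs in inputs_list:
        torch.cuda.synchronize()
        start = time.perf_counter()
        model.generate(inputs.to(model.device, dtype=model.dtype))
        torch.cuda.synchronize()
        latencies.append(time.perf_counter() - start)
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"peak VRAM: {peak_gb:.2f} GB")
    print(f"median latency: {sorted(latencies)[len(latencies) // 2] * 1e3:.1f} ms")
    print(f"throughput: {len(latencies) / sum(latencies):.2f} samp/s")
```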

Δ vs fp32: −3.49 GB peak VRAM, +8.0% throughput. Char-level Levenshtein
distance per sample between the fp32 and fp16 transcripts: 0.00 mean,
100/100 identical, i.e. no quality regression. Verified via torch.profiler
that the default SDPA dispatch picks the mem-efficient backend in fp32 and
the flash kernel in fp16.
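
A minimal sketch of this kind of dispatch check, using illustrative tensor shapes rather than Whisper's actual attention shapes:

```python
import torch
import torch.nn.functional as F

def profile_sdpa(dtype):
    # Illustrative shapes (batch, heads, seq, head_dim); the real check
    # profiled the Whisper generate() call inside the PO pipeline.
    q = torch.randn(1, 16, 1500, 64, device="cuda", dtype=dtype)
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA]
    ) as prof:
        F.scaled_dot_product_attention(q, q, q)
    # fp16/bf16 inputs can hit the flash kernel; fp32 falls back to the
    # mem-efficient backend, which shows up in the kernel names below.
    print(dtype)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=5))

profile_sdpa(torch.float32)
profile_sdpa(torch.float16)
```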

Collection: tts

Changelog

  • nemo/collections/tts/models/magpietts_preference_optimization.py: add
    torch_dtype="auto" to all three WhisperForConditionalGeneration.from_pretrained
    call sites (MagpieTTSModelOfflinePODataGen.__init__,
    MagpieTTSModelOnlinePO.__init__ reward-ASR path,
    MagpieTTSModelOnlinePO.__init__ load_whisper_model path).

Usage

No user-facing API change. No config changes required for existing users.

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests? — N/A: change is a single dtype kwarg replicated across three identical loading blocks; behavior is exercised by existing PO pipelines.
  • Did you add or update any necessary documentation? — N/A: internal load behavior only.
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc) — No new optional deps. Whisper is loaded via transformers, which is already imported behind the existing if cfg.get(...) gates.
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

cc @blisc @okuchaiev for TTS review.

Additional Information

  • Not tied to a specific issue. Discovered while profiling GRPO training memory on multi-node A100 runs.

HF whisper-large-v3 ships fp16 weights, but from_pretrained without
torch_dtype defaults to an fp32 upcast, which doubles VRAM (~3 GB) and
prevents SDPA from dispatching to the flash kernel (flash-SDPA is
fp16/bf16 only). Passing torch_dtype="auto" tells HuggingFace to load
the model in the dtype declared by the checkpoint's config.json
(float16 for whisper-large-v3).

Verified on A100 with 100 CML-TTS IT samples: -3.5 GB peak VRAM,
+8% throughput, 100/100 transcripts identical to fp32 baseline.

Signed-off-by: matteolippi <matteolippi.science@gmail.com>
@copy-pr-bot

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

