[MagpieTTS] Load Whisper with torch_dtype="auto" for fp16 inference #15680
Open
matteolippi wants to merge 1 commit into NVIDIA-NeMo:main from
Conversation
HF whisper-large-v3 ships fp16 weights, but from_pretrained without torch_dtype defaults to fp32 upcast, which doubles VRAM (~3 GB) and prevents SDPA from dispatching to the flash kernel (flash-SDPA is fp16/bf16 only). Passing torch_dtype="auto" tells HuggingFace to load the model in the dtype declared by the checkpoint's config.json (float16 for whisper-large-v3). Verified on A100 with 100 CML-TTS IT samples: -3.5 GB peak VRAM, +8% throughput, 100/100 transcripts identical to fp32 baseline. Signed-off-by: matteolippi <matteolippi.science@gmail.com>
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR. Please reach out to the automation team before pressing that button.
What does this PR do?
Pass torch_dtype="auto" to WhisperForConditionalGeneration.from_pretrained in the MagpieTTS preference-optimization models, so Whisper loads in its native fp16 dtype instead of being implicitly upcast to fp32.
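The upcast described above can be illustrated with plain torch (a toy analogy, not the HF API): copying an fp16 state dict into a module constructed in fp32 silently casts the weights up, which mirrors what happens when from_pretrained is called without torch_dtype.

```python
import torch
import torch.nn as nn

# Build an fp16 "checkpoint" for a toy layer.
fp16_state = {k: v.half() for k, v in nn.Linear(4, 4).state_dict().items()}

# Destination constructed in fp32 (the default): load_state_dict copies the
# fp16 tensors into fp32 parameters, upcasting them.
upcast = nn.Linear(4, 4)
upcast.load_state_dict(fp16_state)
print(upcast.weight.dtype)  # torch.float32

# Destination constructed in fp16: the checkpoint dtype is preserved.
native = nn.Linear(4, 4).half()
native.load_state_dict(fp16_state)
print(native.weight.dtype)  # torch.float16
```

torch_dtype="auto" plays the role of constructing the destination in the checkpoint's own dtype.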
HF whisper-large-v3 ships fp16 weights, but from_pretrained without torch_dtype defaults to an fp32 upcast, which doubles VRAM (~3 GB extra) and prevents torch.nn.functional.scaled_dot_product_attention from dispatching to the flash kernel (flash-SDPA is fp16/bf16 only; fp32 falls back to the mem-efficient backend).
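The dtype constraint is on the inputs: SDPA keeps the input dtype, and only fp16/bf16 inputs are eligible for the flash backend on CUDA. A minimal CPU sketch of the dtype propagation (actual backend dispatch between flash, mem-efficient, and math is a CUDA concern; bf16 is used here because it is equally flash-eligible and well supported on CPU):

```python
import torch
import torch.nn.functional as F

# Half-precision inputs -> half-precision attention output (flash-eligible on CUDA).
q = torch.randn(1, 2, 16, 8, dtype=torch.bfloat16)
out = F.scaled_dot_product_attention(q, q, q)
print(out.dtype)  # torch.bfloat16

# fp32 inputs stay fp32 and can never hit the flash kernel.
q32 = q.float()
out32 = F.scaled_dot_product_attention(q32, q32, q32)
print(out32.dtype)  # torch.float32
```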
torch_dtype="auto" tells HuggingFace to load the model in the dtype declared by the checkpoint's config.json (float16 for whisper-large-v3).
Verified on A100 with 100 CML-TTS Italian samples (batch_size=1, greedy): Δ vs fp32: −3.49 GB peak VRAM, +8.0% throughput. Char-level Levenshtein distance per sample, fp32 ↔ fp16: 0.00 mean, 100/100 identical; no quality regression. Verified via torch.profiler that the default SDPA dispatch picks mem-efficient in fp32 and flash in fp16.
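A back-of-envelope check that the measured VRAM delta is in the right ballpark (parameter memory only; activations and the KV cache add more, so a peak delta slightly above this is plausible; 1.55e9 is the approximate whisper-large-v3 parameter count, an assumption here):

```python
# Parameter memory at 4 bytes (fp32) vs 2 bytes (fp16) per weight.
params = 1.55e9
fp32_gb = params * 4 / 1e9
fp16_gb = params * 2 / 1e9
saved_gb = fp32_gb - fp16_gb
print(round(saved_gb, 2))  # ~3.1 GB saved on weights alone
```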
Collection: tts
Changelog
nemo/collections/tts/models/magpietts_preference_optimization.py: add torch_dtype="auto" to all three WhisperForConditionalGeneration.from_pretrained call sites (MagpieTTSModelOfflinePODataGen.__init__, MagpieTTSModelOnlinePO.__init__ reward-ASR path, MagpieTTSModelOnlinePO.__init__ load_whisper_model path).
Usage
No user-facing API change. No config changes required for existing users.
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
transformers, which is already imported behind the existing if cfg.get(...) gates.
PR Type:
Who can review?
cc @blisc @okuchaiev for TTS review.
Additional Information