Skip to content

fix(stable_audio): align batched initial audio with prompts (#13629)#13659

Open
Anai-Guo wants to merge 1 commit intohuggingface:mainfrom
Anai-Guo:fix-stable-audio-prompt-alignment
Open

fix(stable_audio): align batched initial audio with prompts (#13629)#13659
Anai-Guo wants to merge 1 commit intohuggingface:mainfrom
Anai-Guo:fix-stable-audio-prompt-alignment

Conversation

@Anai-Guo
Copy link
Copy Markdown

Summary

Fixes Issue 1 in #13629.

StableAudioPipeline.prepare_latents() was expanding batched initial audio with encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)), which produces an interleaved layout [audio0, audio1, audio0, audio1]. The corresponding text+duration embeds are expanded per prompt by encode_prompt() to [prompt0, prompt0, prompt1, prompt1]. For batched audio-to-audio with num_waveforms_per_prompt > 1, prompts and initial audio became misaligned, so each generation could be conditioned on another prompt’s initial audio. Existing tests only assert output shape and missed it.

Fix

Switch repeat to repeat_interleave(..., dim=0) so batched initial audio expands as [audio0, audio0, audio1, audio1], matching the prompt expansion.

Verification

The reproduction snippet from #13629 now prints the expected order:

import torch
from diffusers import StableAudioPipeline
# (same DummyVAE / SimpleNamespace scaffolding as in the issue)
print(latents[:, 0, 0].tolist())  # [10.0, 10.0, 20.0, 20.0]

num_waveforms_per_prompt == 1 is unchanged because repeat_interleave and repeat are equivalent in that case.

🤖 Generated with Claude Code

@github-actions github-actions Bot added pipelines size/S PR with diff < 50 LOC labels Apr 30, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pipelines size/S PR with diff < 50 LOC

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant