Skip to content

Support real-valued input in torch.fft.irfft converter#2721

Open
adityasingh2400 wants to merge 1 commit into
apple:mainfrom
adityasingh2400:fix-irfft-real-input
Open

Support real-valued input in torch.fft.irfft converter#2721
adityasingh2400 wants to merge 1 commit into
apple:mainfrom
adityasingh2400:fix-irfft-real-input

Conversation

@adityasingh2400
Copy link
Copy Markdown

Summary

Converting a model that calls torch.fft.irfft on a real-valued tensor fails:

Op "complex_irfft_0" (op_type: complex_irfft) Input data="k_f" expects
tensor or scalar of dtype from type domain ['complex64'] but got tensor[3,fp32]

torch.fft.irfft accepts a real-valued input and treats it as a complex tensor with a zero imaginary part (a one-sided Hermitian spectrum is real at DC). The fft_irfft handler passed the traced input straight to the complex_irfft dialect op, which only accepts complex64, so any real input is rejected at conversion time. The underlying IRFFT lowering already exists and works; only the real-to-complex promotion was missing.

Fixes #2130.

Changes

coremltools/converters/mil/frontend/torch/ops.py

  • In fft_irfft, if the input is not already complex, build a zero imaginary part with mb.fill and wrap it into a complex value via mb.complex before calling complex_irfft. This mirrors what the complex_fft / complex_fftn lowering already does for real inputs and matches PyTorch semantics. Complex inputs are passed through unchanged.

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

  • Added test_fft_irfft_real_input to TestFft, parametrized over n, dim, and norm. Unlike the existing test_fft_basic, the input is not wrapped with torch.complex, so it covers the real-input path from the issue.

Testing

Red/green verified end to end (ct.convert then predict on ComputeUnit.CPU_ONLY, compared against torch.fft.irfft):

  • RED: without the fix the new test fails with the complex_irfft dtype error above.
  • GREEN: with the fix, test_fft_irfft_real_input passes (36/36 cases), and the existing TestFft tests are unaffected.

Numeric agreement against PyTorch across default args, explicit n (pad and trim), every norm mode, and 1D/2D/3D inputs:

  • FLOAT32 compute precision: max abs diff ~1e-7.
  • FLOAT16 compute precision: max abs diff ~1e-3, in line with the existing complex-input irfft path (same DFT-by-matmul lowering).

The original issue snippet now converts and produces [0.15625, 0.78125, 0.78125, 0.78125], matching PyTorch exactly.

Implemented with the help of a coding agent.

torch.fft.irfft accepts a real-valued input and treats it as a complex
tensor with a zero imaginary part. The converter passed the input
straight to the complex_irfft dialect op, which only accepts complex64,
so converting a model that calls irfft on a real tensor failed with:

  Op "complex_irfft_0" (op_type: complex_irfft) Input data=... expects
  tensor or scalar of dtype from type domain ['complex64'] but got
  tensor[3,fp32]

Promote a real input to complex with a zero imaginary part before
lowering, matching PyTorch semantics. Complex inputs are unchanged.

Added a regression test exercising real-input irfft across n, dim and
norm. Fixes apple#2130.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Input error during converting torch.fft.irfft

1 participant