Summary
torch.compile(fullgraph=True) does not work with FlyDSL-compiled kernels.
Two distinct failure modes are reproduced:
- Compiling the kernel inside torch.compile
- Calling a pre-compiled kernel inside torch.compile
System configuration
- GPU: MI355X
- OS: Ubuntu 24.04
- PyTorch: 2.10.0+rocm7.2.2.git40d237bf
Repro
Code is available here.
Both scripts work correctly in eager mode and only fail at the torch.compile step.
Case 1: compilation inside torch.compile
FLYDSL_RUNTIME_ENABLE_CACHE=0 python tests/kernels/test_gemm_torch_compile.py
Case 2: invocation of pre-compiled kernel inside torch.compile
FLYDSL_RUNTIME_ENABLE_CACHE=0 python tests/kernels/test_gemm_torch_compile_wrapper.py