Skip to content

[DO NOT MERGE] Draft of combined cudnn + hipdnn SDPA backend#3216

Draft
rkayaith wants to merge 10 commits into
ROCm:hipdnn_developfrom
rkayaith:sdpa-unify
Draft

[DO NOT MERGE] Draft of combined cudnn + hipdnn SDPA backend#3216
rkayaith wants to merge 10 commits into
ROCm:hipdnn_developfrom
rkayaith:sdpa-unify

Conversation

@rkayaith
Copy link
Copy Markdown

@rkayaith rkayaith commented May 12, 2026

Primary changes are here: 9bb8752

This adds extra include paths to cmake, where shim headers are placed. The shim headers re-route CUDA includes to the HIP equivalent, with the necessary defines etc. to preserve compatibility. Only shims necessary for SDPA were added.

Key files to look at:

The key differences ended up being:

  • hipDNN requires scale to be set on the graph as a constant, while cuDNN binds it through a host buffer at execute time. This requires an extra field in the cache key for the scale value + the plumbing to populate it.

    • cuDNN binds via SDPA_attributes::set_attn_scale(attn_scale) + {SCALE, &scaling_factor} in the variant_pack: MHA.cpp:551.
    • hipDNN bakes via SDPA_attributes::set_attn_scale_value(scaling_factor): MHA.cpp:540.
    • extra cache key field: MHA.cpp:239 (MHAParams::scaling_factor under #ifdef USE_HIPDNN).
    • Population of params.scaling_factor: MHA.cpp:291.
  • Code dependent on cuDNN frontend version needs hipDNN handled too:

    • SDPA_attributes::set_generate_stats (in hipdnn and newer cudnn) vs SDPA_attributes::set_is_inference (older cudnn): MHA.cpp:530-534.
    • Graph::get_workspace_size: out-param form (hipdnn / newer cudnn) vs by-value (older cudnn): MHA.cpp:1570-1575
  • Ragged/nested tensor APIs don't exist in hipDNN yet, calls have to be ifdefd out:

    • SDPA_attributes::set_seq_len_{q,kv} (and the same on SDPA_backward_attributes)
    • SDPA_attributes::set_padding_mask (and the same on SDPA_backward_attributes)
    • Tensor_attributes::set_ragged_offset
  • cuDNN constraints are slightly different. cuDNN checks kernel support in PyTorch rather than through API queries.

  • hipDNN has slight differences in a few graph APIs that the shim handles. All in cudnn_frontend.h:

    • cuDNN's Graph::check_support(handle) and Graph::build_plans(handle) take a handle; hipDNN's don't. Shim adds handle-taking overloads: cudnn_frontend.h:34-35.
    • cuDNN's Graph::query_tensor_attributes_of_uid(uid, attrs) is a per-uid out-param query; hipDNN only offers a one-shot Graph::getTensorsByUid() map. Shim wraps it: cudnn_frontend.h:39-50.
    • cuDNN's HeurMode_t::A (recommended heuristic) maps to hipDNN's FALLBACK. Shim shadows the enum with a struct so call sites can use fe::HeurMode_t::A uniformly: cudnn_frontend.h:55-58.

Besides the conditional logic, there's some code that's technically there on both backends but could be simplified if targeting only one:

zjgarvey and others added 10 commits April 15, 2026 09:26
Integrate HipDNN with PyTorch when available (requires ROCm 7.12+).

Includes:

- cmake `USE_HIPDNN` detection
- runtime `CUDAHooks::compiledWithHipDNN()` hook
- simple `torch.backends.hipdnn` Python module

Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com>
Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Assisted-by: Claude Opus 4.6 <noreply@anthropic.com>

remove macro variable AT_HIPDNN_ENABLED

The macro was causing internal build failures. The purpose of the macro
is to throw a compile error when the generated header isn't included,
but we only use the macro in a file where this is the case anyway, so we
might as well directly query the associated preprocessor directive.

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Implements forward and backward convolution (2D/3D) through the hipDNN
frontend graph API, providing an alternative to the MIOpen backend on
ROCm.

- Graph-cached convolution: Forward (fprop), backward-data (dgrad),
and backward-weight (wgrad) via hipDNN frontend graphs with a
thread-local LRU cache (`ParamsLRUCache<K,V>`) to amortize
`graph->build()` cost
- Dispatch integration: New `ConvBackend::Hipdnn` and
`ConvBackend::HipdnnTranspose` variants wired through backend selection,
memory format selection, forward/backward switches, and Python enum
exposure. hipDNN takes priority over MIOpen when
`torch.backends.hipdnn.enabled` is `True`

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Assisted-by: Claude Opus 4.6 <noreply@anthropic.com>
Adds support for selecting hipdnn as a backend for `batch_norm`.

Unfortunately, batch norm is in a half-migrated state wrt. new dispatch
stack. Consequently, this PR adds a new backend-specific frontend op.

RFC: https://dev-discuss.pytorch.org/t/rfc-adding-a-batch-norm-backend-revisiting-dispatch-stack-issues/3327

Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com>
Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Assisted-by: Claude Opus 4.6 <noreply@anthropic.com>
These files are guarded by `#if AT_CUDNN_ENABLED()` which is always 0
on ROCm, so only stub implementations compile. Hipify was making
text substitutions (cudnnHandle_t → miopenHandle_t, etc.) to code
that is entirely dead on ROCm.

Include fixes needed to compile without hipify:
- RNN.cpp: move CUDAEvent.h, CUDAGraphsUtils.cuh, Exceptions.h into
  the `#else // AT_CUDNN_ENABLED()` block (only used by real impl)
- LossCTC.cpp: remove unused CUDAGraphsUtils.cuh include
- BatchNorm.cpp, Module.cpp, attention.cu, attention_backward.cu:
  remove `#ifdef __HIP_PLATFORM_AMD__` guards that selected hipified
  header paths (cudnn/hip/MHA.h, cudnn/hip/BatchNorm.h) — use the
  originals directly since hipify no longer runs on these files

The quantized/cudnn/ files additionally had redundant `#ifdef USE_CUDA`
guards wrapping the entire file. These are only compiled in CUDA/ROCm
builds (gated by cmake), so the guards were dead code.

Authored with Claude.
Building with USE_FLASH_ATTENTION=ON on ROCm copies precompiled AOTriton
kernel images into torch/lib/aotriton.images/. These are binary GPU kernels
for flash and efficient attention, shipped precompiled in the ROCm SDK.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Factor out the stride computation logic from alloc_with_matching_layout
into a standalone compute_matching_strides function that returns strides
without allocating a tensor. This allows callers that only need the
output stride metadata (e.g., graph-based support checks) to avoid
unnecessary GPU tensor allocations.

For the same-size case, delegates to infer_dense_strides to match
empty_like's compaction behavior on non-dense inputs. For different
sizes, computes dense strides preserving the reference tensor's
dimension ordering.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add an optional 'scale' field to sdp::sdp_params and populate it from every dispatch entry point (_fused_sdp_choice_cpp/_cuda/_xpu, the transformer_encoder helper, and the SDPAParams Python binding). The hipDNN backend reads this in check_cudnn_sdpa_support so the support query and the eventual graph build see the actual scale the user passed, instead of always defaulting to 1/sqrt(head_dim).
Uses shim headers to re-route CUDA includes to hip/hipdnn. Some differences still need conditional logic in the source files:
- hipDNN requires `scale` value to be set on the graph as a constant, while cuDNN accepts it through a host buffer at execute time. This requires an extra field in the cache key for the scale value.
- code dependant on cuDNN version requires logic to handle hipDNN
- ragged/nested tensors aren't supported on hipDNN due to missing APIs:
  - set_seq_len_{q,k,v}
  - set_padding_mask
  - set_ragged_offset
- cuDNN constrains are checked in pytorch, rather than through API queries
  - increases coupling, but *does* allow requirements to be checked symbolically without requiring concrete dimension values
Adds `test_fused_attention_custom_scale` parameterized over
PLATFORM_SPECIFIC_SDPA (flash, efficient, cudnn). Each fused backend
runs SDPA with a non-default `scale=` argument and is compared against
the math backend with the same scale. No existing PyTorch SDPA test
exercises a non-default scale on a fused/cuDNN backend.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@rocm-repo-management-api
Copy link
Copy Markdown

rocm-repo-management-api Bot commented May 12, 2026

Jenkins build for 77cc7f93bb76bf91e080646aa085272c7516ae69 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants