[DO NOT MERGE] Draft of combined cudnn + hipdnn SDPA backend#3216
Draft
rkayaith wants to merge 10 commits into
Draft
[DO NOT MERGE] Draft of combined cudnn + hipdnn SDPA backend#3216rkayaith wants to merge 10 commits into
rkayaith wants to merge 10 commits into
Conversation
Integrate HipDNN with PyTorch when available (requires ROCm 7.12+). Includes: - cmake `USE_HIPDNN` detection - runtime `CUDAHooks::compiledWithHipDNN()` hook - simple `torch.backends.hipdnn` Python module Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com> Signed-off-by: zjgarvey <zjgarvey@gmail.com> Assisted-by: Claude Opus 4.6 <noreply@anthropic.com> remove macro variable AT_HIPDNN_ENABLED The macro was causing internal build failures. The purpose of the macro is to throw a compile error when the generated header isn't included, but we only use the macro in a file where this is the case anyway, so we might as well directly query the associated preprocessor directive. Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Implements forward and backward convolution (2D/3D) through the hipDNN frontend graph API, providing an alternative to the MIOpen backend on ROCm. - Graph-cached convolution: Forward (fprop), backward-data (dgrad), and backward-weight (wgrad) via hipDNN frontend graphs with a thread-local LRU cache (`ParamsLRUCache<K,V>`) to amortize `graph->build()` cost - Dispatch integration: New `ConvBackend::Hipdnn` and `ConvBackend::HipdnnTranspose` variants wired through backend selection, memory format selection, forward/backward switches, and Python enum exposure. hipDNN takes priority over MIOpen when `torch.backends.hipdnn.enabled` is `True` Signed-off-by: zjgarvey <zjgarvey@gmail.com> Assisted-by: Claude Opus 4.6 <noreply@anthropic.com>
Adds support for selecting hipdnn as a backend for `batch_norm`. Unfortunately, batch norm is in a half-migrated state wrt. new dispatch stack. Consequently, this PR adds a new backend-specific frontend op. RFC: https://dev-discuss.pytorch.org/t/rfc-adding-a-batch-norm-backend-revisiting-dispatch-stack-issues/3327 Co-authored-by: Dmitry Nikolaev <dmitry.nikolaev@amd.com> Signed-off-by: zjgarvey <zjgarvey@gmail.com> Assisted-by: Claude Opus 4.6 <noreply@anthropic.com>
These files are guarded by `#if AT_CUDNN_ENABLED()` which is always 0 on ROCm, so only stub implementations compile. Hipify was making text substitutions (cudnnHandle_t → miopenHandle_t, etc.) to code that is entirely dead on ROCm. Include fixes needed to compile without hipify: - RNN.cpp: move CUDAEvent.h, CUDAGraphsUtils.cuh, Exceptions.h into the `#else // AT_CUDNN_ENABLED()` block (only used by real impl) - LossCTC.cpp: remove unused CUDAGraphsUtils.cuh include - BatchNorm.cpp, Module.cpp, attention.cu, attention_backward.cu: remove `#ifdef __HIP_PLATFORM_AMD__` guards that selected hipified header paths (cudnn/hip/MHA.h, cudnn/hip/BatchNorm.h) — use the originals directly since hipify no longer runs on these files The quantized/cudnn/ files additionally had redundant `#ifdef USE_CUDA` guards wrapping the entire file. These are only compiled in CUDA/ROCm builds (gated by cmake), so the guards were dead code. Authored with Claude.
Building with USE_FLASH_ATTENTION=ON on ROCm copies precompiled AOTriton kernel images into torch/lib/aotriton.images/. These are binary GPU kernels for flash and efficient attention, shipped precompiled in the ROCm SDK. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Factor out the stride computation logic from alloc_with_matching_layout into a standalone compute_matching_strides function that returns strides without allocating a tensor. This allows callers that only need the output stride metadata (e.g., graph-based support checks) to avoid unnecessary GPU tensor allocations. For the same-size case, delegates to infer_dense_strides to match empty_like's compaction behavior on non-dense inputs. For different sizes, computes dense strides preserving the reference tensor's dimension ordering. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add an optional 'scale' field to sdp::sdp_params and populate it from every dispatch entry point (_fused_sdp_choice_cpp/_cuda/_xpu, the transformer_encoder helper, and the SDPAParams Python binding). The hipDNN backend reads this in check_cudnn_sdpa_support so the support query and the eventual graph build see the actual scale the user passed, instead of always defaulting to 1/sqrt(head_dim).
Uses shim headers to re-route CUDA includes to hip/hipdnn. Some differences still need conditional logic in the source files:
- hipDNN requires `scale` value to be set on the graph as a constant, while cuDNN accepts it through a host buffer at execute time. This requires an extra field in the cache key for the scale value.
- code dependant on cuDNN version requires logic to handle hipDNN
- ragged/nested tensors aren't supported on hipDNN due to missing APIs:
- set_seq_len_{q,k,v}
- set_padding_mask
- set_ragged_offset
- cuDNN constrains are checked in pytorch, rather than through API queries
- increases coupling, but *does* allow requirements to be checked symbolically without requiring concrete dimension values
Adds `test_fused_attention_custom_scale` parameterized over PLATFORM_SPECIFIC_SDPA (flash, efficient, cudnn). Each fused backend runs SDPA with a non-default `scale=` argument and is compared against the math backend with the same scale. No existing PyTorch SDPA test exercises a non-default scale on a fused/cuDNN backend. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
|
Jenkins build for 77cc7f93bb76bf91e080646aa085272c7516ae69 commit finished as FAILURE |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Primary changes are here: 9bb8752
This adds extra include paths to cmake, where shim headers are placed. The shim headers re-route CUDA includes to the HIP equivalent, with the necessary defines etc. to preserve compatibility. Only shims necessary for SDPA were added.
Key files to look at:
USE_HIPDNN)USE_HIPDNN)The key differences ended up being:
hipDNN requires
scaleto be set on the graph as a constant, while cuDNN binds it through a host buffer at execute time. This requires an extra field in the cache key for the scale value + the plumbing to populate it.SDPA_attributes::set_attn_scale(attn_scale)+{SCALE, &scaling_factor}in the variant_pack: MHA.cpp:551.SDPA_attributes::set_attn_scale_value(scaling_factor): MHA.cpp:540.MHAParams::scaling_factorunder#ifdef USE_HIPDNN).params.scaling_factor: MHA.cpp:291.Code dependent on cuDNN frontend version needs hipDNN handled too:
SDPA_attributes::set_generate_stats(in hipdnn and newer cudnn) vsSDPA_attributes::set_is_inference(older cudnn): MHA.cpp:530-534.Graph::get_workspace_size: out-param form (hipdnn / newer cudnn) vs by-value (older cudnn): MHA.cpp:1570-1575Ragged/nested tensor APIs don't exist in hipDNN yet, calls have to be
ifdefd out:SDPA_attributes::set_seq_len_{q,kv}(and the same onSDPA_backward_attributes)SDPA_attributes::set_padding_mask(and the same onSDPA_backward_attributes)Tensor_attributes::set_ragged_offsetcuDNN constraints are slightly different. cuDNN checks kernel support in PyTorch rather than through API queries.
at::native::check_cudnn_sdpa_supportat the end of the constraint chain: sdp_utils.cpp:833. It builds the graph and runsGraph::is_supported_ext. On cuDNN it short-circuits toreturn true(MHA.cpp:2051-2053).sdp_utils.cpp: check_hipdnn_enabled, check_no_nested_inputs_hipdnn, check_dtypes_hipdnn.hipDNN has slight differences in a few graph APIs that the shim handles. All in cudnn_frontend.h:
Graph::check_support(handle)andGraph::build_plans(handle)take a handle; hipDNN's don't. Shim adds handle-taking overloads: cudnn_frontend.h:34-35.Graph::query_tensor_attributes_of_uid(uid, attrs)is a per-uid out-param query; hipDNN only offers a one-shotGraph::getTensorsByUid()map. Shim wraps it: cudnn_frontend.h:39-50.HeurMode_t::A(recommended heuristic) maps to hipDNN'sFALLBACK. Shim shadows the enum with a struct so call sites can usefe::HeurMode_t::Auniformly: cudnn_frontend.h:55-58.Besides the conditional logic, there's some code that's technically there on both backends but could be simplified if targeting only one:
hipDNN doesn't support nested tensors. A bunch of nested-tensor code is left in but is effectively dead on hipDNN and could just be deleted. Functions whose bodies are entirely cuDNN-only (entry-point bails with
TORCH_CHECK(false, ...)on hipDNN):build_graph_nestedtensorbuild_graph_backward_nestedtensorrun_cudnn_SDP_fprop_nestedtensorrun_cudnn_SDP_bprop_nestedtensorSame goes for the ragged-in-dense path.
use_raggedis always false on hipDNN, so the supporting logic and cache-key field are dead too:use_ragged_in_denseMHAParams::use_raggedhipDNN wants to query graph support before actual execution, which complicates logic involving tensors that haven't been allocated yet.
build_graphandbuild_graph_backward(originallyo,softmaxstats,dropoutseed,dropoutoffsetfor the forward; plusdO,dQ,dK,dVfor the backward). Each removed tensor's dim/stride is now derived fromb/h/s_q/s_kv/d_qk/d_v+ the dtype info inside the build functions, instead of read directly off a tensor:O: shape{b, h, s_q, d_v}, strides viacompute_matching_stridesagainst Q's layout: MHA.cpp:619-622 (forward), MHA.cpp:1053-1057 (backward, also reused fordO).softmaxstats/Stats: contiguous{b, h, s_q, 1}(matchesat::emptyallocation): MHA.cpp:623-628.dropoutseed/dropoutoffset: hardcoded as{1, 1, 1, 1}with dtypeINT64based on the documented allocation contract fromattention.cpp/attention.cu: MHA.cpp:577-590 (forward), MHA.cpp:1022-1033 (backward).dQ/dK/dV: reuseq.sizes()/q.strides()etc. (the gradient tensors areat::empty_like(q)so their layout matches): MHA.cpp:1145-1147.