Enable fused RMSNorm dLN + add through CUDNN #2778
base: main
Changes from all commits: 8525799, debe130, 793d4cd, 2134cc6, 118a108, a16b8d8, 71db968
```diff
@@ -395,6 +395,23 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
     std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]);
     if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned.");
   }
+  // Fuse the add for BackwardAdd stage
+  if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
+    NVTE_CHECK(cudnnGetVersion() >= 92100,
+               "Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion());
+
+    _add = _graph.tensor(fe::graph::Tensor_attributes()
+                             .set_name("add")
+                             .set_dim({batch_dim, hidden_dim, 1, 1})
+                             .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
+                             .set_data_type(get_cudnn_fe_dtype(wtype)));
+    auto add_options = fe::graph::Pointwise_attributes()
+                           .set_mode(fe::PointwiseMode_t::ADD)
+                           .set_compute_data_type(get_cudnn_fe_dtype(ctype));
+    auto _dx_with_add = _graph.pointwise(_dx, _add, add_options);
+    _dx->set_output(false).set_data_type(get_cudnn_fe_dtype(itype));
+    _dx = _dx_with_add;
+  }
   _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
   _dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
 }
```
```diff
@@ -467,13 +484,16 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
                                      void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
                                      void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
                                      void* workspace_dptr, cudaStream_t stream) {
-  // cuDNN does not currently support fused backward+add
-  NVTE_CHECK(add_dptr == nullptr);
-
   // Binding data pointers to graph tensors
   _variant_pack = {
       {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
```
**Comment on lines 488 to 489** (Contributor):

> In the variant pack initialization [...]. However, a concern worth verifying: in the [...]
>
> Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
```diff
+  // Bind the add tensor for fused backward+add
+  if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
+    NVTE_CHECK(add_dptr != nullptr, "add_dptr must not be null for BackwardAdd");
+    _variant_pack.insert({{_add, add_dptr}});
+  }
+
   if (_zero_centered)
     _variant_pack.insert({{_scalar_offset, reinterpret_cast<void*>(this->_scalar_dptr.get())},
                          {_gamma_zero, gamma_dptr}});
```
**Contributor:**

> The `NVTE_CHECK(cudnnGetVersion() >= 92100, ...)` fires during `CudnnNormalizationPlan` construction, which is called from `getNormalizationPlan`. If a user sets `NVTE_NORM_BWD_USE_CUDNN=1` on a system with cuDNN < 9.21, the entire `rmsnorm_bwd_add` call will throw a hard error rather than gracefully falling back to the TE kernel. Unlike typical cuDNN support checks elsewhere in the codebase that gate entry into the cuDNN path, this check is buried inside plan construction, where there is no recovery mechanism.
>
> Consider validating the cuDNN version at the top of `rmsnorm_bwd_add` in `rmsnorm_api.cpp` (inside the `use_cudnn_norm_bwd()` branch) and falling back to the TE backend when the version requirement isn't met, consistent with how other backend-selection logic is structured.

**Reply:**

> Not sure what the best behaviour here is. If `NVTE_NORM_BWD_USE_CUDNN=1` is set and we cannot use cuDNN, I would expect it to fail early, and for the user to remove that env var, rather than silently falling back to TE.

**Reply:**

> You're right: if a user explicitly sets `NVTE_NORM_BWD_USE_CUDNN=1`, they're requesting a specific backend and should expect it to fail loudly rather than silently fall back. The explicit opt-in makes the current behavior reasonable.
>
> One clarification: does the regular (non-BackwardAdd) cuDNN normalization work with cuDNN < 9.21? If so, a user might set `NVTE_NORM_BWD_USE_CUDNN=1` for general cuDNN norm support and be surprised when BackwardAdd specifically fails. In that case, the error message could be slightly more actionable: [...]
>
> But if BackwardAdd is only used when explicitly requested (not auto-selected based on the operation pattern), then the current behavior and message are fine as-is.