26 changes: 23 additions & 3 deletions transformer_engine/common/normalization/common.cpp
@@ -395,6 +395,23 @@ CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Nor
std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]);
if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned.");
}
// Fuse the add for BackwardAdd stage
if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
NVTE_CHECK(cudnnGetVersion() >= 92100,
"Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion());
Comment on lines +399 to +401 (Contributor):
P1 Version check throws inside constructor with no fallback path

The NVTE_CHECK(cudnnGetVersion() >= 92100, ...) fires during CudnnNormalizationPlan construction, which is called from getNormalizationPlan. If a user sets NVTE_NORM_BWD_USE_CUDNN=1 on a system with cuDNN < 9.21, the entire rmsnorm_bwd_add call will throw a hard error rather than gracefully falling back to the TE kernel. Unlike typical cuDNN support checks elsewhere in the codebase that gate entry into the cuDNN path, this check is buried inside plan construction where there is no recovery mechanism.

Consider validating the cuDNN version at the top of rmsnorm_bwd_add in rmsnorm_api.cpp (inside the use_cudnn_norm_bwd() branch) and falling back to the TE backend when the version requirement isn't met, consistent with how other backend-selection logic is structured.

Author:

I'm not sure what the best behaviour is here. If NVTE_NORM_BWD_USE_CUDNN=1 is set and we cannot use cuDNN, I would expect it to fail early and for the user to remove that env var, rather than silently falling back to TE.

Contributor:

You're right – if a user explicitly sets NVTE_NORM_BWD_USE_CUDNN=1, they're requesting a specific backend and should expect it to fail loudly rather than silently fall back. The explicit opt-in makes the current behavior reasonable.

One clarification: does the regular (non-BackwardAdd) cuDNN normalization work with cuDNN < 9.21? If so, a user might set NVTE_NORM_BWD_USE_CUDNN=1 for general cuDNN norm support and be surprised when BackwardAdd specifically fails. In that case, the error message could be slightly more actionable:

"Fused BackwardAdd requires cuDNN >= 9.21.0 (found " + std::to_string(cudnnGetVersion()) + 
"). Either upgrade cuDNN or unset NVTE_NORM_BWD_USE_CUDNN to use the TE backend."

But if BackwardAdd is only used when explicitly requested (not auto-selected based on the operation pattern), then the current behavior and message are fine as-is.


_add = _graph.tensor(fe::graph::Tensor_attributes()
.set_name("add")
.set_dim({batch_dim, hidden_dim, 1, 1})
.set_stride({hidden_dim, 1, hidden_dim, hidden_dim})
.set_data_type(get_cudnn_fe_dtype(wtype)));
auto add_options = fe::graph::Pointwise_attributes()
.set_mode(fe::PointwiseMode_t::ADD)
.set_compute_data_type(get_cudnn_fe_dtype(ctype));
auto _dx_with_add = _graph.pointwise(_dx, _add, add_options);
_dx->set_output(false).set_data_type(get_cudnn_fe_dtype(itype));
_dx = _dx_with_add;
}
_dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
_dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype));
}
@@ -467,13 +484,16 @@ void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_
void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
// cuDNN does not currently support fused backward+add
NVTE_CHECK(add_dptr == nullptr);

// Binding data pointers to graph tensors
_variant_pack = {
{_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
Comment on lines 488 to 489 (Contributor):

P2 _dgamma mapped in variant pack but not an output tensor for BackwardAdd stage

In the variant pack initialization {_dgamma, dgamma_dptr} is always included. For the BackwardAdd stage, _dgamma->set_output(true) is called from line 417 (_dgamma->set_output(true).set_data_type(...)), so cuDNN will write to dgamma_dptr. This looks fine.

However, a concern worth verifying: in the BackwardAdd path, _dx (the member) has been reassigned to _dx_with_add. The original intermediate rmsnorm _dx tensor (the input to the pointwise add) was set to set_output(false). It is not in the _variant_pack, which is correct — cuDNN handles it as an internal virtual tensor. No binding is required for it. This is working correctly but is subtle; a code comment explaining that the intermediate dx does not need a binding would improve maintainability.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


// Bind the add tensor for fused backward+add
if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
NVTE_CHECK(add_dptr != nullptr, "add_dptr must not be null for BackwardAdd");
_variant_pack.insert({{_add, add_dptr}});
}

if (_zero_centered)
_variant_pack.insert({{_scalar_offset, reinterpret_cast<void*>(this->_scalar_dptr.get())},
{_gamma_zero, gamma_dptr}});
2 changes: 1 addition & 1 deletion transformer_engine/common/normalization/common.h
@@ -294,7 +294,7 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
std::shared_ptr<fe::graph::Tensor_attributes> _z_mx_row, _z_mx_col, _sf_row, _sf_col;
const bool _training;
// BWD
std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta;
std::shared_ptr<fe::graph::Tensor_attributes> _dz, _dx, _dgamma, _dbeta, _add;

fe::graph::Graph _graph;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> _variant_pack;
23 changes: 14 additions & 9 deletions transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
@@ -206,16 +206,21 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
CheckOutputTensor(*dgamma, "dgamma");
}

// cuDNN does not currently support fused backward+add
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;

// TE backend does not currently support zero_centered_gamma_in_weight_dtype
NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
"zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add");

bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dgamma->data.dptr, add.data.dptr);
NVTE_Norm_Backend norm_backend;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
if (use_cudnn_norm_bwd()) {
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else {
norm_backend = NVTE_Norm_Backend::Te;
// TE backend does not currently support zero_centered_gamma_in_weight_dtype
NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
"zero_centered_gamma_in_weight_dtype is currently not supported "
"for rmsnorm_bwd_add with TE backend");
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dgamma->data.dptr, add.data.dptr);
}

auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,