Skip to content

Commit e0792f5

Browse files
committed
Fix the retrieval of overwrite_main_grad: fetch it from the original weight rather than the fp8 (fp4) weight, since the fp8 (fp4) weight may not inherit the required attributes when it is created.
1 parent b7214fd commit e0792f5

3 files changed

Lines changed: 4 additions & 4 deletions

File tree

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
476476
use_split_accumulator=wgrad_gemm_use_split_accumulator,
477477
accumulate=(
478478
accumulate_wgrad_into_param_main_grad
479-
if not getattr(weights[0], "overwrite_main_grad", False)
479+
if not getattr(origin_weights[0], "overwrite_main_grad", False)
480480
else False
481481
),
482482
)

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@ def backward(
864864
"quantization_params": ctx.grad_weight_quantizer,
865865
"accumulate": (
866866
accumulate_wgrad_into_param_main_grad
867-
if not getattr(weight, "overwrite_main_grad", False)
867+
if not getattr(origin_weight, "overwrite_main_grad", False)
868868
else False
869869
),
870870
"layout": "NT",

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ def backward(
12241224
"quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision
12251225
"accumulate": (
12261226
accumulate_wgrad_into_param_main_grad
1227-
if not getattr(fc1_weight, "overwrite_main_grad", False)
1227+
if not getattr(origin_fc1_weight, "overwrite_main_grad", False)
12281228
else False
12291229
),
12301230
"layout": "NT",
@@ -1471,7 +1471,7 @@ def fc2_wgrad_gemm(
14711471
"quantization_params": ctx.fc1_grad_weight_quantizer,
14721472
"accumulate": (
14731473
accumulate_wgrad_into_param_main_grad
1474-
if not getattr(fc2_weight, "overwrite_main_grad", False)
1474+
if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
14751475
else False
14761476
),
14771477
"layout": "NT",

0 commit comments

Comments (0)