3 changes: 3 additions & 0 deletions transformer_engine/pytorch/attention.py
@@ -3170,6 +3170,7 @@ def __init__(
qkv_weight_interleaved: bool = True,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_overlap_rs_dgrad: bool = False,
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
bias: bool = True,
@@ -3258,6 +3259,7 @@ def __init__(
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_name="qkv",
@@ -3289,6 +3291,7 @@ def __init__(
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_overlap_rs_dgrad=ub_overlap_rs_dgrad,
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_name="qkv",
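
Taken together, the attention.py change is pure plumbing: the new ub_overlap_rs_dgrad flag is accepted by the attention module's constructor and forwarded unchanged to the submodules built with ub_name="qkv" (most likely LayerNormLinear, given the normalization arguments passed alongside it). A minimal usage sketch follows; the class name MultiheadAttention and the surrounding arguments are assumptions on my part, since the excerpt only shows part of the constructor signature.

import transformer_engine.pytorch as te

# Hypothetical usage sketch -- assumes the constructor above belongs to
# te.MultiheadAttention and that userbuffer communicators have already been
# created for the "qkv" layers (e.g. via initialize_ub) with tensor parallelism set up.
mha = te.MultiheadAttention(
    hidden_size=4096,
    num_attention_heads=32,
    input_layernorm=True,        # the overlap applies to the fused LayerNorm + QKV GEMM
    ub_bulk_wgrad=False,         # mutually exclusive with the new flag: backward() forces
    ub_bulk_dgrad=False,         #   both bulk overlaps off when ub_overlap_rs_dgrad is set
    ub_overlap_rs_dgrad=True,    # new: overlap the dgrad GEMM with its reduce-scatter
    ub_overlap_ag=True,
)
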
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/module/base.py
@@ -137,6 +137,7 @@ def initialize_ub(
"bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]

def get_method(name):
for method, names in methods.items():
@@ -207,6 +208,14 @@ def add_ub(
)
_ub_communicators[name] = ub_obj

if ub_cfgs is not None:
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in ub_cfgs, f"{wgrad_name} cannot use bulk overlap when {name} overlaps dgrad with reduce-scatter."
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_reduce_scatter_overlap.append(name)

for name in (methods["ring_exchange"]+methods["pipeline"]+methods["bulk"]):
if ub_cfgs is not None and name in ub_cfgs:
ub_cfg = ub_cfgs[name]
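
Functionally, the base.py addition lets ub_cfgs opt a dgrad GEMM into reduce-scatter overlap: if the config for "qkv_dgrad" or "fc1_dgrad" names any method other than "bulk", that layer is moved into the reduce-scatter-overlap group and the corresponding *_wgrad entry is removed from it (the assert forbids configuring both at once). Below is a standalone sketch of that rewriting with a hypothetical ub_cfgs; it mirrors the loop above but is not library code.

# Standalone sketch of the group rewriting added above; the example ub_cfgs is
# hypothetical and only the "method" key is exercised here.
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]

ub_cfgs = {"qkv_dgrad": {"method": "pipeline"}}  # overlap qkv dgrad with a pipelined reduce-scatter

for name in dgrad_reduce_scatter_overlap:
    if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
        wgrad_name = name.replace("dgrad", "wgrad")       # "qkv_wgrad"
        assert wgrad_name not in ub_cfgs                  # bulk wgrad overlap would conflict
        layers_reduce_scatter_overlap.remove(wgrad_name)
        layers_reduce_scatter_overlap.append(name)

print(layers_reduce_scatter_overlap)
# -> ['proj_fprop', 'fc2_fprop', 'fc1_wgrad', 'qkv_dgrad']
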
63 changes: 53 additions & 10 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -86,6 +86,7 @@ def forward(
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_overlap_rs_dgrad: bool,
ub_overlap_ag: bool,
ub_name: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
@@ -293,6 +294,7 @@ def forward(
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
ctx.ub_name = ub_name
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
@@ -344,9 +346,15 @@ def backward(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
)

if ctx.ub_bulk_dgrad:
if ctx.ub_overlap_rs_dgrad:
ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_overlap_rs_dgrad = False
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1 or not weight.requires_grad:
ctx.ub_bulk_dgrad = False
if ctx.ub_bulk_dgrad:
dim_size = list(ln_out.size())
@@ -393,9 +401,35 @@
if ctx.ub_bulk_wgrad: # allocate dgrad output
ub_obj_dgrad = get_ub(ctx.ub_name+"_wgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
elif ctx.ub_overlap_rs_dgrad:
ub_obj_dgrad = get_ub(ctx.ub_name+"_dgrad")
dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
else:
dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)

if ctx.ub_bulk_dgrad:
ub_algo = tex.UbufOverlapAlgo.BULK_OVERLAP_AG
ub_obj = ub_obj_lnout
elif ctx.ub_overlap_rs_dgrad:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(1)
rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=grad_output.device)
if ub_obj_dgrad.is_p2p_overlap():
if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
ub_obj = ub_obj_dgrad
else:
ub_algo = None
ub_obj = None

if ctx.fp8:
fp8_dtype_forward = get_fp8_te_dtype(
ctx.fp8_meta["recipe"], fprop_tensor=True
@@ -405,14 +439,14 @@
)
out_index, meta_tensor, out_te_type, out_type = (
None, None, None, ctx.activation_dtype)
if ctx.ub_bulk_wgrad and ub_obj_dgrad.is_fp8_ubuf():
if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf():
out_index = tex.FP8BwdTensors.GRAD_INPUT1
meta_tensor = ctx.fp8_meta["scaling_bwd"]
out_te_type = fp8_dtype_backward
out_type = torch.uint8
ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])

# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = tex.fp8_gemm(
weight_t_fp8._data,
fwd_scale_inverses,
@@ -426,8 +460,9 @@
get_workspace(),
out=dgrad,
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None,
ub_algo=ub_algo,
ub=ub_obj,
extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
out_index=out_index,
fp8_meta_tensor = meta_tensor,
D_dtype = out_te_type,
@@ -443,8 +478,9 @@
out=dgrad,
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
ub_algo=ub_algo,
ub=ub_obj,
extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None,
)
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
@@ -453,7 +489,7 @@
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if not ctx.ub_bulk_dgrad and handle is not None:
handle.wait()
if not ctx.ub_bulk_wgrad:
if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgrad, handle = reduce_scatter_along_first_dim(
@@ -546,7 +582,10 @@ def backward(
handle.wait()

# LayerNorm gradient
dgrad = dgrad.view(inputmat.shape)
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out.view(inputmat.shape)
else:
dgrad = dgrad.view(inputmat.shape)

# Residual gradient
if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
@@ -622,6 +661,7 @@ def backward(
None,
None,
None,
None,
)


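The extra None in the return tuple above is standard torch.autograd.Function bookkeeping: backward() must return one value per argument of forward(), with None for non-tensor arguments, so the new ub_overlap_rs_dgrad flag added to forward() needs a matching None here. A generic illustration of the pattern (toy code, not from Transformer Engine):

import torch

class ScaleByFlag(torch.autograd.Function):
    # Toy Function: forward() takes a tensor and a bool, so backward() must
    # return two values -- a gradient for the tensor and None for the bool.
    @staticmethod
    def forward(ctx, inp, double_it):
        ctx.double_it = double_it
        return inp * 2 if double_it else inp

    @staticmethod
    def backward(ctx, grad_out):
        grad_inp = grad_out * 2 if ctx.double_it else grad_out
        return grad_inp, None  # one entry per forward() argument

x = torch.randn(4, requires_grad=True)
ScaleByFlag.apply(x, True).sum().backward()
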
@@ -735,6 +775,7 @@ def __init__(
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_overlap_ag: bool = False,
ub_overlap_rs_dgrad: bool = False,
ub_name: Optional[str] = None,
) -> None:
super().__init__()
Expand All @@ -755,7 +796,8 @@ def __init__(
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_overlap_ag = ub_overlap_ag
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]):
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name

@@ -1087,6 +1129,7 @@ def forward(
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_overlap_ag,
self.ub_name,
)
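
The core of the layernorm_linear.py change is the dgrad GEMM in backward(): the overlap algorithm is now chosen from the active mode and the capabilities of the userbuffer object, and with ub_overlap_rs_dgrad the reduce-scattered gradient comes back through extra_output_tensor (rs_out) and feeds the LayerNorm gradient in place of the full dgrad. The sketch below restates that selection in isolation; the enum values and method names come from the diff, while the helper itself and the extension-module import name are assumptions.

# Illustrative helper mirroring the ub_algo selection added to backward().
# The extension-module name is an assumption -- use whatever this file imports as `tex`.
import transformer_engine_extensions as tex

def pick_dgrad_ub_algo(ub_bulk_dgrad, ub_overlap_rs_dgrad, ub_obj_dgrad, fp8):
    # Returns the userbuffer overlap algorithm for the dgrad GEMM, or None.
    if ub_bulk_dgrad:
        # Existing path: overlap the dgrad GEMM with a bulk AllGather of ln_out (needed for wgrad).
        return tex.UbufOverlapAlgo.BULK_OVERLAP_AG
    if ub_overlap_rs_dgrad:
        # New path: overlap the dgrad GEMM with the reduce-scatter of its own output.
        if ub_obj_dgrad.is_p2p_overlap():
            if fp8 and ub_obj_dgrad.is_atomic_gemm():
                return tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
            return tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
        if fp8 and ub_obj_dgrad.is_atomic_gemm():
            return tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
        return tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
    return None

The two modes stay mutually exclusive because they overlap the same GEMM with different collectives: bulk overlap piggybacks the AllGather of ln_out needed for wgrad, whereas the new mode piggybacks the reduce-scatter of dgrad itself, which is why backward() turns both bulk flags off when ub_overlap_rs_dgrad is enabled.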