214 changes: 208 additions & 6 deletions tests/pytorch/test_cuda_graphs.py
@@ -630,12 +630,185 @@ def test_make_graphed_callables_with_kwargs(
assert_all_equal(outputs, graph_outputs)


def test_make_graphed_callables_returns_owned_parameter_grads() -> None:
"""Parameter grads returned from graph replay must not alias static graph buffers."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad = seen_grads[0]
first_grad_ptr = first_grad.data_ptr()
first_grad_snapshot = first_grad.clone()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert first_grad.data_ptr() == first_grad_ptr
assert seen_grads[1].data_ptr() != first_grad_ptr
torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0)
finally:
hook.remove()
reset_graphs(model)
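
The first test above targets a hazard that can be reproduced with a bare torch.cuda.CUDAGraph, independent of Transformer Engine: tensors produced during replay live in the graph's static buffers, so a later replay silently overwrites a value observed after an earlier replay. A minimal sketch of that aliasing (illustrative only, not part of this PR; tensor names and shapes are invented, and a CUDA device is required):

import torch

static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")

# Warm up the captured op on a side stream, as CUDA graph capture requires.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(side_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out.copy_(static_in * 2)

static_in.fill_(1.0)
graph.replay()
first = static_out                    # aliases the static output buffer
first_snapshot = static_out.clone()   # holds the value seen after this replay

static_in.fill_(3.0)
graph.replay()                        # overwrites the buffer `first` still points to
assert not torch.equal(first, first_snapshot)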


def test_make_graphed_callables_accumulates_owned_parameter_grads() -> None:
"""Parameter grad accumulation must not reuse overwritten static graph buffers."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)

input_1 = generate_data(model_config, dtype, requires_grad=False)
grad_1 = generate_data(model_config, dtype, requires_grad=False)
input_2 = generate_data(model_config, dtype, requires_grad=False)
grad_2 = generate_data(model_config, dtype, requires_grad=False)
expected_grad = torch.einsum("...o,...i->oi", grad_1, input_1) + torch.einsum(
"...o,...i->oi", grad_2, input_2
)

try:
model.zero_grad(set_to_none=True)
model(input_1).backward(grad_1)
model(input_2).backward(grad_2)
torch.testing.assert_close(model.weight.grad, expected_grad, rtol=0, atol=0)
finally:
reset_graphs(model)
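
The expected_grad expression above is the standard weight gradient of a bias-free Linear layer: for y = x @ W.T, the weight gradient is the outer product of the output gradient and the input, summed over all leading dimensions, which is exactly what the einsum contraction computes. A quick standalone check of that identity (illustrative only; shapes are invented):

import torch

x = torch.randn(5, 3)
weight = torch.nn.Parameter(torch.randn(4, 3))
grad_y = torch.randn(5, 4)

y = torch.nn.functional.linear(x, weight)   # y = x @ weight.T, shape (5, 4)
y.backward(grad_y)

expected = torch.einsum("...o,...i->oi", grad_y, x)   # same contraction as the tests above
torch.testing.assert_close(weight.grad, expected)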


def test_make_graphed_callables_preserves_skipped_parameter_grad_alias() -> None:
"""Delayed-wgrad parameters are excluded from returned-grad clone handling."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model.weight.skip_backward_post_hook = True
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad_ptr = seen_grads[0].data_ptr()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert seen_grads[1].data_ptr() == first_grad_ptr
finally:
hook.remove()
reset_graphs(model)


def test_make_graphed_callables_snapshots_parameter_grad_clone_policy() -> None:
"""Parameter grad clone policy is fixed at capture time."""
reset_rng_states()
model_config = model_configs["small"]
dtype = torch.float32
model = torch.nn.Linear(
model_config.hidden_size,
model_config.hidden_size,
bias=False,
device="cuda",
dtype=dtype,
)
model = make_graphed_callables(
model,
(generate_data(model_config, dtype, warmup=True, requires_grad=False),),
)
model.weight.skip_backward_post_hook = True

seen_grads = []

def save_grad(grad):
seen_grads.append(grad)
return grad

hook = model.weight.register_hook(save_grad)
try:
output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 1
first_grad = seen_grads[0]
first_grad_ptr = first_grad.data_ptr()
first_grad_snapshot = first_grad.clone()

model.zero_grad(set_to_none=True)

output = model(generate_data(model_config, dtype, requires_grad=False))
output.backward(generate_data(model_config, dtype, requires_grad=False))

assert len(seen_grads) == 2
assert seen_grads[1].data_ptr() != first_grad_ptr
torch.testing.assert_close(first_grad, first_grad_snapshot, rtol=0, atol=0)
finally:
hook.remove()
reset_graphs(model)


def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
*,
with_graph: bool,
model_config: ModelConfig,
dtype: torch.dtype,
) -> List[torch.Tensor]:
reuse_graph_input_output_buffers: bool = False,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Simulate Megatron-LM interleaved pipeline parallelism."""
reset_rng_states()

@@ -675,6 +848,7 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(
sample_args,
allow_unused_input=True,
_order=layer_order,
_reuse_graph_input_output_buffers=reuse_graph_input_output_buffers,
)
layer_forwards = {
(i // num_microbatches, i % num_microbatches): forward
@@ -701,11 +875,15 @@ def _test_cuda_graphs_with_interleaved_pipeline_parallelism(

# Cache for layer outputs.
outputs = {}
output_snapshots = {} if reuse_graph_input_output_buffers else None

def forward(layer_idx: int, microbatch_idx: int):
"""Helper function for forward steps"""
idxs = (layer_idx, microbatch_idx)
outputs[idxs] = layer_forwards[idxs](inputs[idxs])
if output_snapshots is not None:
# Reused graph output buffers are only valid until their corresponding backward.
output_snapshots[idxs] = outputs[idxs].detach().clone()

def backward(layer_idx: int, microbatch_idx: int):
"""Helper function for backward steps"""
@@ -728,11 +906,13 @@ def backward(layer_idx: int, microbatch_idx: int):
# Optimizer step.
optimizer.step()

outputs = [y for _, y in sorted(outputs.items())]
outputs = get_outputs(model, outputs)
output_values = output_snapshots if output_snapshots is not None else outputs
output_values = [y for _, y in sorted(output_values.items())]
outputs = get_outputs(model, output_values)
final_weights = [param.detach().clone() for param in model.parameters()]
if with_graph:
reset_graphs(layer_forwards)
return outputs
return outputs, final_weights


def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
@@ -743,12 +923,34 @@ def test_make_graphed_callables_with_interleaved_pipeline_parallelism(
"""Test CUDA graphs with Megatron-LM interleaved pipeline parallelism."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
assert_all_equal(weights, graph_weights)


def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
*,
model_config: str = "small",
dtype: torch.dtype = torch.float16,
) -> None:
"""Test CUDA graphs with reused input/output buffers."""
model_config = model_configs[model_config]
kwargs = dict(model_config=model_config, dtype=dtype)
outputs, weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=False,
**kwargs,
)
graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
graph_outputs, graph_weights = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
with_graph=True,
reuse_graph_input_output_buffers=True,
**kwargs,
)
assert_all_equal(outputs, graph_outputs)
Comment on lines +938 to +955
Contributor

P2 Reused-buffer test only validates forward outputs, not gradient correctness

test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers compares output_snapshots (forward tensors cloned before the corresponding backward) against the eager baseline. If the clone-on-return logic in Graphed.backward had a bug specifically in the _reuse_graph_input_output_buffers + pipeline path (e.g., gradient accumulation or an incorrect static buffer being read), weights would diverge but the test would still pass. A weight-equality check after one full schedule would strengthen confidence in the gradient path for this mode.

Contributor Author

Addressed in 4077b85. The interleaved pipeline helper now returns final weights in addition to outputs, and the reused-buffer test compares graph/eager final weights to cover gradient correctness. Full tests/pytorch/test_cuda_graphs.py passed on H100: 415 passed, 423 skipped.

assert_all_equal(weights, graph_weights)
59 changes: 55 additions & 4 deletions transformer_engine/pytorch/graph.py
@@ -407,6 +417,17 @@ def _make_graphed_callables(
bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
graph_callables = [None for _ in range(len(flatten_sample_args))]

def _returned_param_grad_slots(static_grad_inputs, module_params):
"""Snapshot static grad slots that are consumed through Graphed.backward."""
module_param_start = len(static_grad_inputs) - len(module_params)
return tuple(
idx >= module_param_start
and not getattr(
module_params[idx - module_param_start], "skip_backward_post_hook", False
)
for idx in range(len(static_grad_inputs))
)
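
For reference, the mask produced by this helper marks which flattened grad slots belong to module parameters whose grads are returned (and therefore cloned) from Graphed.backward; slots for user inputs and for parameters tagged skip_backward_post_hook stay False. A toy illustration (the helper body is copied from the diff above because in the source it is a closure inside _make_graphed_callables; names and shapes are invented):

import torch

def _returned_param_grad_slots(static_grad_inputs, module_params):
    module_param_start = len(static_grad_inputs) - len(module_params)
    return tuple(
        idx >= module_param_start
        and not getattr(
            module_params[idx - module_param_start], "skip_backward_post_hook", False
        )
        for idx in range(len(static_grad_inputs))
    )

input_grad_slot = torch.zeros(2)                  # grad slot for a user input
weight = torch.nn.Parameter(torch.zeros(2, 2))    # ordinary parameter -> grad cloned on return
bias = torch.nn.Parameter(torch.zeros(2))
bias.skip_backward_post_hook = True               # delayed-wgrad parameter -> static alias kept

mask = _returned_param_grad_slots(
    (input_grad_slot, torch.zeros(2, 2), torch.zeros(2)),
    (weight, bias),
)
assert mask == (False, True, False)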

# For cases with multiple active RNG states, e.g. TP.
if graph_safe_rng_available():
for _, state in get_all_rng_states().items():
@@ -569,6 +580,7 @@ def hook_fn(
per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
per_callable_static_grad_outputs = [None] * len(flatten_sample_args)
per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
per_callable_returned_param_grad_slots = [None] * len(flatten_sample_args)
fwd_idx = [0] * num_model_chunks
bwd_idx = [0] * num_model_chunks
static_grad_outputs_dict = {}
@@ -716,6 +728,13 @@ def hook_fn(

per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
returned_param_grad_slots = _returned_param_grad_slots(
static_grad_inputs,
per_callable_module_params[per_callable_bwd_idx],
)
per_callable_returned_param_grad_slots[per_callable_bwd_idx] = (
returned_param_grad_slots
)

# Weak ref the static outputs and static grad inputs that are no longer needed
# in the following steps. These two type of tensors are both in cudagraph
@@ -728,6 +747,18 @@ def hook_fn(
static_outputs
)

# Parameter grads are cloned before being returned from
# Graphed.backward, so their static buffers can be weak-refed now.
static_grad_inputs = per_callable_static_grad_inputs[per_callable_bwd_idx]
per_callable_static_grad_inputs[per_callable_bwd_idx] = tuple(
(
make_weak_ref(grad_input)
if returned_param_grad_slots[idx] and grad_input is not None
else grad_input
)
for idx, grad_input in enumerate(static_grad_inputs)
)

# Weak ref the static grad inputs of the previous backward pass within the
# same chunk.
if previous_per_callable_bwd_idx is not None:
@@ -769,6 +800,7 @@ def hook_fn(
# Capture backward graphs in reverse order
per_callable_static_grad_outputs = []
per_callable_static_grad_inputs = []
per_callable_returned_param_grad_slots = []
for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip(
reversed(per_callable_static_input_surfaces),
reversed(per_callable_static_outputs),
@@ -813,10 +845,19 @@ def hook_fn(

per_callable_static_grad_outputs.append(static_grad_outputs)
per_callable_static_grad_inputs.append(static_grad_inputs)
per_callable_returned_param_grad_slots.append(
_returned_param_grad_slots(
static_grad_inputs,
per_callable_module_params[bwd_idx],
)
)

# Reverses the most recent two lists
# Reverse the most recent per-callable lists.
per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
per_callable_returned_param_grad_slots = list(
reversed(per_callable_returned_param_grad_slots)
)
# Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

def make_graphed_autograd_function(
@@ -830,6 +871,7 @@ def make_graphed_autograd_function(
static_outputs,
static_grad_outputs,
static_grad_inputs,
returned_param_grad_slots,
):
class Graphed(torch.autograd.Function):
"""Autograd function for graph replay."""
@@ -911,9 +953,17 @@ def backward(ctx, *grads):
"Expected static_grad_inputs to be a tuple, but got"
f" {type(static_grad_inputs).__name__}"
)
return (None, None, None) + tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)
grad_inputs = []
for idx, grad_input in enumerate(static_grad_inputs):
if grad_input is None:
grad_inputs.append(None)
elif returned_param_grad_slots[idx]:
# Returned parameter grads may be installed directly as param.grad.
# Clone to avoid exposing CUDA graph static buffers to autograd users.
grad_inputs.append(grad_input.detach().clone())
else:
grad_inputs.append(grad_input.detach())
return (None, None, None) + tuple(grad_inputs)

def functionalized(*user_args, **user_kwargs):

@@ -1008,6 +1058,7 @@ def reset():
per_callable_static_outputs[i],
per_callable_static_grad_outputs[i],
per_callable_static_grad_inputs[i],
per_callable_returned_param_grad_slots[i],
)

func = graph_callables[i]