
Commit dbb9d14

Add TP to FSDP and HSDP functional tests.

Signed-off-by: Cory Ye <cye@nvidia.com>

1 parent: ccc0abc

8 files changed: 204 additions & 34 deletions

File: tests/pytorch/distributed/run_fsdp2_model.py (177 additions & 24 deletions)
@@ -7,6 +7,7 @@
 import os
 import sys
 import argparse
+from dataclasses import dataclass

 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import (

@@ -18,6 +19,13 @@

 import torch
 import torch.distributed as dist
+from torch.distributed.checkpoint import save, load
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_state_dict,
+    set_state_dict,
+)
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.tensor import DTensor
 import torch.nn.functional as F
 from torch import nn, optim

@@ -30,6 +38,61 @@
 LOCAL_RANK = None


+@dataclass
+class AppState(Stateful):
+    """AppState for FSDP2 checkpoint via Torch DCP.
+
+    Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+    """
+
+    model: torch.nn.Module
+    optimizer: torch.optim.Optimizer
+
+    def state_dict(self):
+        """
+        Get the state dict for the model, optimizer, scheduler, and step.
+        This factory both retrieves the model state dictionary when saving
+        checkpoints and initializes a destination for the state read from
+        DCP checkpoint files when loading checkpoints.
+        """
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+        for fqn in list(model_state_dict.keys()):
+            # Get the model parameter.
+            model_param = model_state_dict[fqn]
+            if isinstance(model_param, DTensor):
+                model_param = model_param.to_local()
+            if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
+                # Empty model parameter. Clear the associated optimizer state
+                # when initializing the optimizer state upon DCP load, because
+                # empty optimizer state DTensors are not checkpointed with DCP,
+                # yet get_state_dict / _init_optim_state produce empty Tensors.
+                # TransformerEngine uses empty Tensors for dummy Parameters.
+                optimizer_state_dict["state"][fqn] = {}
+            if fqn.endswith("._extra_state"):
+                # Evict `_extra_state` quantization data from model checkpoint.
+                model_state_dict.pop(fqn)
+        return {
+            "model": model_state_dict,
+            "optim": optimizer_state_dict,
+        }
+
+    def load_state_dict(self, state_dict: dict):
+        """
+        Load the state dict for the model, optimizer, scheduler, and step.
+        Given the checkpoint-loaded state_dict, set the state of the model,
+        optimizer, scheduler, step, and epoch to the values in state_dict.
+        """
+        set_state_dict(
+            self.model,
+            self.optimizer,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=state_dict["optim"],
+            # Non-strict checkpoint loading ignores empty optimizer states,
+            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
+            options=StateDictOptions(strict=False),
+        )
+
+
 def dist_print(msg):
     if LOCAL_RANK == 0:
         print(msg)
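
The AppState wrapper is driven by Torch DCP through a dict of Stateful objects. A minimal usage sketch, assuming an FSDP2-sharded TE `model` and its `optimizer`; the checkpoint path is illustrative and not part of this diff:

```python
# Hypothetical usage of the AppState class above with Torch DCP.
import torch.distributed.checkpoint as dcp

app_state = AppState(model=model, optimizer=optimizer)

# Save: DCP calls app_state.state_dict() on every rank and writes the shards.
dcp.save({"app": app_state}, checkpoint_id="/tmp/fsdp2_ckpt")

# Load: DCP calls app_state.state_dict() to build a destination, reads the
# shards into it, then hands the result to app_state.load_state_dict().
dcp.load({"app": app_state}, checkpoint_id="/tmp/fsdp2_ckpt")
```
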

@@ -82,11 +145,16 @@ def _parse_args(argv=None, namespace=None):
         "--sharding-dims",
         type=int,
         nargs="+",
-        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
+        help='FSDP/HSDP sharding dimensions ("dp_replicate", "dp_shard", "tp")',
     )
     args = parser.parse_args(argv, namespace)
     if args.sharding_dims:
-        assert len(args.sharding_dims) <= 2
+        assert len(args.sharding_dims) <= 3
+        if len(args.sharding_dims) >= 3:
+            # Set the TP size in args.
+            args.tp_size = args.sharding_dims[2]
+        else:
+            args.tp_size = 1
     return args

@@ -136,11 +204,17 @@ def init_te_model(config):
         "params_dtype": params_dtype,
     }
     kwargs["device"] = config.device
+    kwargs["tp_size"] = config.tp_size

     layer_type = get_te_layer_from_string(config.layer_type)
     # We are creating model in a way so that we can test both reshard_after_forward=True/False cases.
     # more details below.
-    if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
+    if layer_type in [
+        te.TransformerLayer,
+        te.MultiheadAttention,
+        te.LayerNormMLP,
+        # TODO(@cspades): GroupedLinear testing.
+    ]:
         # For this case, we are creating a model that resembles production use-cases
         # wherein there are multiple TransformerLayers in the model. And we would need
         # to shard each transformer layer. Since each transformer layer is not a root module,

@@ -150,44 +224,102 @@ def init_te_model(config):
         kwargs["fuse_qkv_params"] = True
         if layer_type is te.MultiheadAttention:
             kwargs["input_layernorm"] = True
+        # DeviceMesh / DTensor-related model parameter operations!
+        # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
+        # If not using meta device initialization, reset_parameters is called during __init__.
+        if config.tp_size > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            assert "tp" in config.mesh.mesh_dim_names
+            dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
+            # Activate TP in TE.
+            kwargs["set_parallel_mode"] = True
+            # For TP shards as DTensors.
+            kwargs["tp_mesh"] = config.mesh["tp"]
+            # For per-tensor quantization recipes with TP.
+            kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
+        elif len(config.mesh.mesh_dim_names) > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
+            # Used for per-tensor quantization recipes like Float8CurrentScaling.
+            kwargs["weight_mesh"] = config.mesh["dp_shard"]  # Only sharding with FSDP.
+        # Initialize model.
         model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
-    elif layer_type == te.LayerNormLinear:
+    elif layer_type in [te.LayerNormLinear, te.Linear]:
         # For this case, we are creating a model with just one LayerNormLinear layer
         # so that the model itself is a root module, and FSDP2's fully_shard assigns
         # reshard_after_forward=True for the parameters of this model.
         args[1] *= 3  # QKV projection
         out_shape[-1] *= 3
+        # DeviceMesh / DTensor-related model parameter operations!
+        # NOTE(@cspades): `set_device_mesh` works, but needs to be called before reset_parameters.
+        # If not using meta device initialization, reset_parameters is called during __init__.
+        if config.tp_size > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            assert "tp" in config.mesh.mesh_dim_names
+            dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
+            # Activate TP in TE.
+            kwargs["parallel_mode"] = "column"
+            # For TP shards as DTensors.
+            kwargs["tp_mesh"] = config.mesh["tp"]
+            # For per-tensor quantization recipes with TP.
+            kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
+            # Modify output shape for column-parallel Linear.
+            out_shape[-1] //= config.tp_size
+        elif len(config.mesh.mesh_dim_names) > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            # HSDP (DP-Repl, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`.
+            # Used for per-tensor quantization recipes like Float8CurrentScaling.
+            kwargs["weight_mesh"] = config.mesh["dp_shard"]  # Only sharding with FSDP.
+        # Initialize model.
         model = layer_type(*args, **kwargs)
     else:
+        # Other TE module. Just ambiguously initialize it.
         model = layer_type(*args, **kwargs)

     return model, inp_shape, out_shape


 def get_device_mesh(world_size, sharding_dims):
-    dist_print(f"sharding-dims:{sharding_dims}")
+    dist_print(f"sharding-dims: {sharding_dims}")
     device_ids = list(range(world_size))
-    if sharding_dims is None:  # FSDP
-        mesh = DeviceMesh("cuda", device_ids)
-    elif len(sharding_dims) == 1:
-        assert sharding_dims[0] == world_size
-        mesh = DeviceMesh("cuda", device_ids)
-    elif len(sharding_dims) == 2:  # HSDP
+    # FSDP
+    if sharding_dims is None or len(sharding_dims) == 1:
+        assert sharding_dims is None or sharding_dims[0] == world_size
+        mesh = init_device_mesh(
+            "cuda",
+            (world_size,),
+            mesh_dim_names=("dp_shard",),
+        )
+    # HSDP
+    elif len(sharding_dims) == 2:
         assert sharding_dims[0] * sharding_dims[1] == world_size
         mesh = init_device_mesh(
             "cuda",
             (sharding_dims[0], sharding_dims[1]),
-            mesh_dim_names=("replicate", "shard"),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+        )
+    # (H/F)SDP-TP
+    elif len(sharding_dims) == 3:
+        assert sharding_dims[0] * sharding_dims[1] * sharding_dims[2] == world_size
+        mesh = init_device_mesh(
+            "cuda",
+            (sharding_dims[0], sharding_dims[1], sharding_dims[2]),
+            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
         )
     else:
+        # Unsupported topology.
         assert False
     return mesh


 def shard_model_with_fsdp2(model, mesh):
+    assert "dp_shard" in mesh.mesh_dim_names
+    dp_dims = (
+        ("dp_replicate", "dp_shard") if "dp_replicate" in mesh.mesh_dim_names else ("dp_shard",)
+    )
     for child in model.children():
-        fully_shard(child, mesh=mesh)
-    fully_shard(model, mesh=mesh)
+        fully_shard(child, mesh=mesh[dp_dims])
+    fully_shard(model, mesh=mesh[dp_dims])
    return model
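
The mesh topology and sub-mesh slicing above can be reproduced standalone. A minimal sketch for the HSDP-TP case, assuming an 8-GPU world and an already-initialized process group:

```python
from torch.distributed.device_mesh import init_device_mesh

# 2-way replicate x 2-way shard x 2-way TP over 8 GPUs (illustrative sizes).
mesh = init_device_mesh(
    "cuda",
    (2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
# fully_shard only consumes the data-parallel dims; TE owns the "tp" dim.
dp_mesh = mesh[("dp_replicate", "dp_shard")]
tp_mesh = mesh["tp"]
```
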

@@ -216,16 +348,18 @@ def restore_custom_attrs(module, custom_attrs):

 @torch.no_grad()
 def test_fp8_fsdp2_allgather(model):
-    # Do manual allgather in fp32 and match against fp8 allgather done
-    # with fsdp2
+    """
+    Compare the result of the FP8 AG by FSDP2 with a manual AG in FP32
+    after dequantizing the FP8 values.
+    """
     # FP32 manual weight allgather
     fp32_allgathered_params = {}
     for name, param in model.named_parameters():
         assert isinstance(param, DTensor)
         local_tensor = param._local_tensor
         device_mesh = param.device_mesh
         dist_group = (
-            device_mesh.get_group(mesh_dim="shard")
+            device_mesh.get_group(mesh_dim="dp_shard")
             if device_mesh.ndim > 1
             else device_mesh.get_group()
         )
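
The body of the gather loop falls outside this hunk. A plausible sketch of a per-parameter FP32 allgather consistent with the comparison below, not the file's actual elided code; `local_tensor`, `dist_group`, and dim-0 concatenation (FSDP2's default sharding dim) come from the context above:

```python
# Dequantize the local shard, then gather it over the sharding group.
world = dist.get_world_size(dist_group)
local_fp32 = local_tensor.dequantize().float()
chunks = [torch.empty_like(local_fp32) for _ in range(world)]
dist.all_gather(chunks, local_fp32.contiguous(), group=dist_group)
fp32_allgathered_params[name] = torch.cat(chunks, dim=0)
```
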

@@ -244,6 +378,10 @@ def test_fp8_fsdp2_allgather(model):
         module.unshard()
     # Make sure allgathered parameters match exactly
     for name, param in model.named_parameters():
+        if isinstance(param, DTensor):
+            # Will still be a DTensor in the case of TP, even after FSDP2 AG,
+            # because we wrap our weights as DTensor shards over the TP group.
+            param = param._local_tensor
         assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
     # Revert model to original sharded state
     for module in model.modules():

@@ -253,6 +391,9 @@


 def _train(args):
+    """
+    Torch Distributed Initialization
+    """
     global LOCAL_RANK
     assert "TORCHELASTIC_RUN_ID" in os.environ
     WORLD_RANK = int(os.getenv("RANK", "0"))

@@ -277,10 +418,20 @@ def _train(args):
     nccl_world = dist.new_group(backend="nccl")
     device = torch.device(f"cuda:{LOCAL_RANK}")

+    # Create a DeviceMesh for fully_shard.
+    world_size = int(WORLD_SIZE)
+    # Setup the sharding mesh for FSDP/HSDP.
+    mesh = get_device_mesh(world_size, args.sharding_dims)
+    args.mesh = mesh
+
+    """
+    TransformerEngine Model Initialization
+    """
     # FP8 Configuration
     fp8_format = Format.HYBRID
     fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

+    # Model initialization context.
     build_model_context_args = {}
     if not args.fp8_init:
         # Build model context (FP8 init)
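
Condensed, the new ordering in _train is: build the mesh, attach it to args, then construct the model. A sketch of why this matters; init_te_model reads config.mesh and config.tp_size set above:

```python
# The mesh must exist before the model is built, because with TP enabled
# the TE modules wrap their weights as DTensors at construction time.
mesh = get_device_mesh(world_size, args.sharding_dims)
args.mesh = mesh  # init_te_model reads config.mesh for tp_mesh / weight_mesh.
model, inp_shape, out_shape = init_te_model(args)
```
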

@@ -301,29 +452,31 @@ def _train(args):
         f" {torch.cuda.memory_allocated(device)/1e6} MB"
     )

-    # Creating a DeviceMesh for fully_shard
-    world_size = int(WORLD_SIZE)
-    # Setup the sharding mesh for FSDP/HSDP
-    mesh = get_device_mesh(world_size, args.sharding_dims)
+    # Avoid passing custom attributes to FSDP2.
     custom_attrs = save_custom_attrs(model)
+    # Fully-shard the model. Will convert model parameters into DTensor
+    # if not already converted by TP.
     model = shard_model_with_fsdp2(model, mesh)
+    # Restore custom attributes on parameters.
     restore_custom_attrs(model, custom_attrs)
-    # model now has DTensors as its parameters

     if args.device == "meta":
         # After FSDP2 has been applied, materialize and initialize the sharded parameters
-        # TE base.py's reset_parameters() handles DTensors with FP8 initialization
+        # TE base.py's reset_parameters() handles DTensors with FP8 initialization.
         for module in model.modules():
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()
         dist_print(f" Sharded parameters materialized and initialized on cuda device.")

     dist_print(
-        f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
+        f"FSDP2 model in CUDA, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB"
     )

     optimizer = optim.Adam(model.parameters(), lr=1e-3)

+    """
+    Pre-Save Training
+    """
     for iteration in range(args.iter):
         # Zero the parameter gradients
         optimizer.zero_grad()

File: tests/pytorch/distributed/test_torch_fsdp2.py (16 additions & 10 deletions)
@@ -19,26 +19,32 @@
 def _run_test(fp_init, sharding_dims, recipe, layer_type):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
     test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
-
     if fp_init:
         test_cmd += ["--fp8-init"]
-
-    if len(sharding_dims) == 1:
-        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
-    elif len(sharding_dims) == 2:
-        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
-    else:
-        assert False
+    test_cmd += ["--sharding-dims"]
+    for x in sharding_dims:
+        test_cmd.append(str(x))
     test_cmd += ["--recipe", recipe]
     test_cmd += ["--layer-type", layer_type]
-
     result = subprocess.run(test_cmd, env=os.environ, check=True)


 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
 @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
-@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
+@pytest.mark.parametrize(
+    "sharding_dims",
+    (
+        # FSDP
+        [NUM_PROCS],
+        # HSDP
+        [2, NUM_PROCS // 2],
+        # FSDP-TP
+        [1, 2, NUM_PROCS // 2],
+        # HSDP-TP
+        [NUM_PROCS // 4, 2, 2],
+    ),
+)
 @pytest.mark.parametrize("fp8_init", (False, True))
 @pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling"))
 @pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer"))
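
For NUM_PROCS == 8, the parametrization above expands to the following concrete topologies; each product must equal the world size, matching the asserts in get_device_mesh. A standalone sketch:

```python
import math

NUM_PROCS = 8  # Illustrative world size.
cases = {
    "FSDP": [NUM_PROCS],                # 8-way shard
    "HSDP": [2, NUM_PROCS // 2],        # 2-way replicate x 4-way shard
    "FSDP-TP": [1, 2, NUM_PROCS // 2],  # no replication, 2-way shard, 4-way TP
    "HSDP-TP": [NUM_PROCS // 4, 2, 2],  # 2 x 2 x 2
}
for name, dims in cases.items():
    assert math.prod(dims) == NUM_PROCS, name
```
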

File: transformer_engine/pytorch/attention/multi_head_attention.py (2 additions & 0 deletions)
@@ -203,6 +203,7 @@ class MultiheadAttention(torch.nn.Module):
     For example:
     - device_mesh["dp"] for FSDP.
     - device_mesh["dp_cp"] if using CP ranks in FSDP.
+    - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
     - device_mesh["tp"] if using TP.
     - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.

@@ -643,6 +644,7 @@ def set_device_mesh(
     For example:
     - device_mesh["dp"] for FSDP.
     - device_mesh["dp_cp"] if using CP ranks in FSDP.
+    - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
     - device_mesh["tp"] if using TP.
     - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
     """

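A hedged sketch of the HSDP usage the new docstring line describes; the module choice, sizes, and the exact set_device_mesh call form are assumptions for illustration, not taken from this diff:

```python
from torch.distributed.device_mesh import init_device_mesh
import transformer_engine.pytorch as te

# Hand the HSDP "dp_shard" sub-mesh to a TE module so that per-tensor
# quantization recipes (e.g. Float8CurrentScaling) can operate over the
# sharding ranks. Sizes and module are illustrative.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
layer = te.MultiheadAttention(1024, 16, device="cuda")
layer.set_device_mesh(mesh["dp_shard"])
```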