 import os
 import sys
 import argparse
+from dataclasses import dataclass

 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import (
     Format,
     # ... (remaining recipe imports collapsed in this view)
 )

 import torch
 import torch.distributed as dist
+from torch.distributed.checkpoint import save, load
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_state_dict,
+    set_state_dict,
+)
+from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.tensor import DTensor
 import torch.nn.functional as F
 from torch import nn, optim
 # ... (unchanged lines collapsed in this view)
 LOCAL_RANK = None


+@dataclass
+class AppState(Stateful):
+    """AppState for FSDP2 checkpointing via Torch DCP.
+
+    Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
+    """
+
+    model: torch.nn.Module
+    optimizer: torch.optim.Optimizer
+
+    def state_dict(self):
+        """
+        Get the state dict for the model and optimizer.
+        This factory both retrieves the model state dictionary when saving
+        checkpoints and initializes a destination for the state read from
+        DCP checkpoint files when loading checkpoints.
+        """
+        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
+        for fqn in list(model_state_dict.keys()):
+            # Get the model parameter.
+            model_param = model_state_dict[fqn]
+            if isinstance(model_param, DTensor):
+                model_param = model_param.to_local()
+            if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
+                # Empty model parameter. Clear the associated optimizer state
+                # when initializing the optimizer state upon DCP load, because
+                # empty optimizer-state DTensors are not checkpointed with DCP,
+                # yet get_state_dict / _init_optim_state produce empty Tensors.
+                # TransformerEngine uses empty Tensors for dummy Parameters.
+                optimizer_state_dict["state"][fqn] = {}
+            if fqn.endswith("._extra_state"):
+                # Evict `_extra_state` quantization data from the model checkpoint.
+                model_state_dict.pop(fqn)
+        return {
+            "model": model_state_dict,
+            "optim": optimizer_state_dict,
+        }
+
+    def load_state_dict(self, state_dict: dict):
+        """
+        Given the checkpoint-loaded state_dict, restore the state of the
+        model and optimizer.
+        """
+        set_state_dict(
+            self.model,
+            self.optimizer,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=state_dict["optim"],
+            # Non-strict checkpoint loading ignores empty optimizer states and
+            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
+            options=StateDictOptions(strict=False),
+        )
+
+
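+# A minimal usage sketch (editor's illustration, not part of this change), assuming a
+# sharded model/optimizer pair and a checkpoint directory visible to all ranks:
+#
+#   app_state = AppState(model, optimizer)
+#   save({"app": app_state}, checkpoint_id="/tmp/fsdp2_ckpt")  # collective DCP save
+#   load({"app": app_state}, checkpoint_id="/tmp/fsdp2_ckpt")  # restores in place
+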
 def dist_print(msg):
     if LOCAL_RANK == 0:
         print(msg)
@@ -82,11 +145,16 @@ def _parse_args(argv=None, namespace=None):
82145 "--sharding-dims" ,
83146 type = int ,
84147 nargs = "+" ,
85- help = 'FSDP/HSDP sharding dimensions ("replicate ", "shard ")' ,
148+ help = 'FSDP/HSDP sharding dimensions ("dp_replicate ", "dp_shard", "tp ")' ,
86149 )
87150 args = parser .parse_args (argv , namespace )
88151 if args .sharding_dims :
89- assert len (args .sharding_dims ) <= 2
152+ assert len (args .sharding_dims ) <= 3
153+ if len (args .sharding_dims ) >= 3 :
154+ # Set the TP size in args.
155+ args .tp_size = args .sharding_dims [2 ]
156+ else :
157+ args .tp_size = 1
90158 return args
91159
92160
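+# Example invocations (editor's illustration, not part of this change), on 8 GPUs:
+#   --sharding-dims 8        ->  FSDP only (dp_shard=8)
+#   --sharding-dims 4 2      ->  HSDP (dp_replicate=4, dp_shard=2)
+#   --sharding-dims 2 2 2    ->  HSDP + TP (dp_replicate=2, dp_shard=2, tp=2)
+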
@@ -136,11 +204,17 @@ def init_te_model(config):
136204 "params_dtype" : params_dtype ,
137205 }
138206 kwargs ["device" ] = config .device
207+ kwargs ["tp_size" ] = config .tp_size
139208
     layer_type = get_te_layer_from_string(config.layer_type)
     # We are creating the model in a way that lets us test both
     # reshard_after_forward=True/False cases; more details below.
-    if layer_type in [te.MultiheadAttention, te.TransformerLayer]:
+    if layer_type in [
+        te.TransformerLayer,
+        te.MultiheadAttention,
+        te.LayerNormMLP,
+        # TODO(@cspades): GroupedLinear testing.
+    ]:
         # For this case, we are creating a model that resembles production use cases,
         # wherein there are multiple TransformerLayers in the model, and we need
         # to shard each transformer layer. Since each transformer layer is not a root module,
@@ -150,44 +224,102 @@ def init_te_model(config):
         kwargs["fuse_qkv_params"] = True
         if layer_type is te.MultiheadAttention:
             kwargs["input_layernorm"] = True
+        # DeviceMesh / DTensor-related model parameter setup.
+        # NOTE(@cspades): `set_device_mesh` works, but needs to be called before
+        # reset_parameters. If not using meta-device initialization,
+        # reset_parameters is called during __init__.
+        if config.tp_size > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            assert "tp" in config.mesh.mesh_dim_names
+            dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
+            # Activate TP in TE.
+            kwargs["set_parallel_mode"] = True
+            # For TP shards as DTensors.
+            kwargs["tp_mesh"] = config.mesh["tp"]
+            # For per-tensor quantization recipes with TP.
+            kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
+        elif len(config.mesh.mesh_dim_names) > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            # HSDP (DP-Replicate, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`,
+            # used for per-tensor quantization recipes like Float8CurrentScaling.
+            kwargs["weight_mesh"] = config.mesh["dp_shard"]  # Only sharding with FSDP.
+        # Initialize the model.
         model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)])
-    elif layer_type == te.LayerNormLinear:
+    elif layer_type in [te.LayerNormLinear, te.Linear]:
         # For this case, we are creating a model with just one linear-type layer
         # so that the model itself is a root module, and FSDP2's fully_shard assigns
         # reshard_after_forward=True for the parameters of this model.
         args[1] *= 3  # QKV projection
         out_shape[-1] *= 3
+        # DeviceMesh / DTensor-related model parameter setup.
+        # NOTE(@cspades): `set_device_mesh` works, but needs to be called before
+        # reset_parameters. If not using meta-device initialization,
+        # reset_parameters is called during __init__.
+        if config.tp_size > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            assert "tp" in config.mesh.mesh_dim_names
+            dist_print(f"Tensor parallelism activated with size: {config.tp_size}")
+            # Activate TP in TE.
+            kwargs["parallel_mode"] = "column"
+            # For TP shards as DTensors.
+            kwargs["tp_mesh"] = config.mesh["tp"]
+            # For per-tensor quantization recipes with TP.
+            kwargs["weight_mesh"] = config.mesh["dp_shard", "tp"]._flatten("weight_mesh")
+            # Modify the output shape for the column-parallel Linear.
+            out_shape[-1] //= config.tp_size
+        elif len(config.mesh.mesh_dim_names) > 1:
+            assert "dp_shard" in config.mesh.mesh_dim_names
+            # HSDP (DP-Replicate, DP-Shard) requires a call to `set_device_mesh(weight_mesh)`,
+            # used for per-tensor quantization recipes like Float8CurrentScaling.
+            kwargs["weight_mesh"] = config.mesh["dp_shard"]  # Only sharding with FSDP.
+        # Initialize the model.
         model = layer_type(*args, **kwargs)
     else:
+        # Some other TE module: initialize it generically.
         model = layer_type(*args, **kwargs)

     return model, inp_shape, out_shape


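+# Editor's note: the two branches above exercise TP differently. Composite modules
+# (TransformerLayer, MultiheadAttention, LayerNormMLP) take `set_parallel_mode=True`,
+# while single linear-type modules take `parallel_mode="column"` and shrink out_shape
+# by the TP size, since each rank produces only its column shard of the output.
+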
 def get_device_mesh(world_size, sharding_dims):
-    dist_print(f"sharding-dims:{sharding_dims}")
+    dist_print(f"sharding-dims: {sharding_dims}")
     device_ids = list(range(world_size))
-    if sharding_dims is None:  # FSDP
-        mesh = DeviceMesh("cuda", device_ids)
-    elif len(sharding_dims) == 1:
-        assert sharding_dims[0] == world_size
-        mesh = DeviceMesh("cuda", device_ids)
-    elif len(sharding_dims) == 2:  # HSDP
+    # FSDP
+    if sharding_dims is None or len(sharding_dims) == 1:
+        assert sharding_dims is None or sharding_dims[0] == world_size
+        mesh = init_device_mesh(
+            "cuda",
+            (world_size,),
+            mesh_dim_names=("dp_shard",),
+        )
+    # HSDP
+    elif len(sharding_dims) == 2:
         assert sharding_dims[0] * sharding_dims[1] == world_size
         mesh = init_device_mesh(
             "cuda",
             (sharding_dims[0], sharding_dims[1]),
-            mesh_dim_names=("replicate", "shard"),
+            mesh_dim_names=("dp_replicate", "dp_shard"),
+        )
+    # (H/F)SDP-TP
+    elif len(sharding_dims) == 3:
+        assert sharding_dims[0] * sharding_dims[1] * sharding_dims[2] == world_size
+        mesh = init_device_mesh(
+            "cuda",
+            (sharding_dims[0], sharding_dims[1], sharding_dims[2]),
+            mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
         )
     else:
+        # Unsupported topology.
         assert False
     return mesh


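+# Editor's illustration: with world_size=8 and sharding_dims=[2, 2, 2], the call
+# above is equivalent to
+#   init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp"))
+# i.e. a 2x2x2 mesh whose "tp" dimension varies fastest across ranks.
+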
 def shard_model_with_fsdp2(model, mesh):
+    assert "dp_shard" in mesh.mesh_dim_names
+    dp_dims = (
+        ("dp_replicate", "dp_shard") if "dp_replicate" in mesh.mesh_dim_names else ("dp_shard",)
+    )
     for child in model.children():
-        fully_shard(child, mesh=mesh)
-    fully_shard(model, mesh=mesh)
+        fully_shard(child, mesh=mesh[dp_dims])
+    fully_shard(model, mesh=mesh[dp_dims])
     return model


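+# Editor's note: fully_shard consumes only the data-parallel mesh dims; with TP
+# enabled, the parameters are already DTensors on the "tp" dim, and FSDP2 layers
+# its dp_shard/dp_replicate placements on top of that sharding.
+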
@@ -216,16 +348,18 @@ def restore_custom_attrs(module, custom_attrs):

 @torch.no_grad()
 def test_fp8_fsdp2_allgather(model):
-    # Do manual allgather in fp32 and match against fp8 allgather done
-    # with fsdp2
+    """
+    Compare the result of the FP8 allgather done by FSDP2 against a manual
+    allgather in FP32, after dequantizing the FP8 values.
+    """
     # FP32 manual weight allgather
     fp32_allgathered_params = {}
     for name, param in model.named_parameters():
         assert isinstance(param, DTensor)
         local_tensor = param._local_tensor
         device_mesh = param.device_mesh
         dist_group = (
-            device_mesh.get_group(mesh_dim="shard")
+            device_mesh.get_group(mesh_dim="dp_shard")
             if device_mesh.ndim > 1
             else device_mesh.get_group()
         )
@@ -244,6 +378,10 @@ def test_fp8_fsdp2_allgather(model):
         module.unshard()
     # Make sure allgathered parameters match exactly
     for name, param in model.named_parameters():
+        if isinstance(param, DTensor):
+            # Still a DTensor in the TP case, even after the FSDP2 allgather,
+            # because we wrap our weights as DTensor shards over the TP group.
+            param = param._local_tensor
         assert torch.allclose(param.dequantize(), fp32_allgathered_params[name])
     # Revert model to original sharded state
     for module in model.modules():
@@ -253,6 +391,9 @@ def test_fp8_fsdp2_allgather(model):


 def _train(args):
+    """
+    Torch Distributed Initialization
+    """
     global LOCAL_RANK
     assert "TORCHELASTIC_RUN_ID" in os.environ
     WORLD_RANK = int(os.getenv("RANK", "0"))
@@ -277,10 +418,20 @@ def _train(args):
     nccl_world = dist.new_group(backend="nccl")
     device = torch.device(f"cuda:{LOCAL_RANK}")

+    # Create a DeviceMesh for fully_shard.
+    world_size = int(WORLD_SIZE)
+    # Set up the sharding mesh for FSDP/HSDP/TP.
+    mesh = get_device_mesh(world_size, args.sharding_dims)
+    args.mesh = mesh
+
+    """
+    TransformerEngine Model Initialization
+    """
     # FP8 Configuration
     fp8_format = Format.HYBRID
     fp8_recipe = get_recipe_from_string(args.recipe, fp8_format)

+    # Model initialization context.
     build_model_context_args = {}
     if not args.fp8_init:
         # Build model context (FP8 init)
@@ -301,29 +452,31 @@ def _train(args):
301452 f" { torch .cuda .memory_allocated (device )/ 1e6 } MB"
302453 )
303454
304- # Creating a DeviceMesh for fully_shard
305- world_size = int (WORLD_SIZE )
306- # Setup the sharding mesh for FSDP/HSDP
307- mesh = get_device_mesh (world_size , args .sharding_dims )
455+ # Avoid passing custom attributes to FSDP2.
308456 custom_attrs = save_custom_attrs (model )
457+ # Fully-shard the model. Will convert model parameters into DTensor
458+ # if not already converted by TP.
309459 model = shard_model_with_fsdp2 (model , mesh )
460+ # Restore custom attributes on parameters.
310461 restore_custom_attrs (model , custom_attrs )
311- # model now has DTensors as its parameters
312462
313463 if args .device == "meta" :
314464 # After FSDP2 has been applied, materialize and initialize the sharded parameters
315- # TE base.py's reset_parameters() handles DTensors with FP8 initialization
465+ # TE base.py's reset_parameters() handles DTensors with FP8 initialization.
316466 for module in model .modules ():
317467 if hasattr (module , "reset_parameters" ):
318468 module .reset_parameters ()
319469 dist_print (f" Sharded parameters materialized and initialized on cuda device." )
320470
321471 dist_print (
322- f"FSDP2 model in cuda , memory allocated: { torch .cuda .memory_allocated (device )/ 1e6 } MB"
472+ f"FSDP2 model in CUDA , memory allocated: { torch .cuda .memory_allocated (device )/ 1e6 } MB"
323473 )
324474
325475 optimizer = optim .Adam (model .parameters (), lr = 1e-3 )
326476
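+    # Editor's note: because the parameters are DTensors after fully_shard, the Adam
+    # states created on the first step() are DTensors with matching placements, which
+    # is what the AppState/DCP checkpointing above relies on.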
477+ """
478+ Pre-Save Training
479+ """
327480 for iteration in range (args .iter ):
328481 # Zero the parameter gradients
329482 optimizer .zero_grad ()