Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions swift/arguments/sft_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,26 @@ def _init_fsdp(self):
# Extract fsdp_config dict
self.fsdp_config = fsdp_config_dict.get('fsdp_config', {})

# Set FSDP_VERSION environment variable for accelerate to recognize FSDP2
fsdp_version = self.fsdp_config.get('fsdp_version', 2)
os.environ['FSDP_VERSION'] = str(fsdp_version)
# Set FSDP environment variables for accelerate
# Map fsdp_config keys to environment variable names
fsdp_env_mapping = {
'fsdp_version': 'FSDP_VERSION',
'state_dict_type': 'FSDP_STATE_DICT_TYPE',
'reshard_after_forward': 'FSDP_RESHARD_AFTER_FORWARD',
'auto_wrap_policy': 'FSDP_AUTO_WRAP_POLICY',
'cpu_ram_efficient_loading': 'FSDP_CPU_RAM_EFFICIENT_LOADING',
}
for config_key, env_var in fsdp_env_mapping.items():
if config_key in self.fsdp_config:
value = self.fsdp_config[config_key]
# Render booleans as lowercase 'true'/'false' strings, the format expected by accelerate
if isinstance(value, bool):
value = str(value).lower()
os.environ[env_var] = str(value)
Comment thread
tzteyang marked this conversation as resolved.

# Default FSDP_VERSION to '2' only when neither fsdp_config nor the existing environment sets it
if 'FSDP_VERSION' not in os.environ:
os.environ['FSDP_VERSION'] = '2'

# Set environment variable to optimize NCCL memory usage
if 'TORCH_NCCL_AVOID_RECORD_STREAMS' not in os.environ:
Expand Down