50 changes: 48 additions & 2 deletions xtuner/v1/model/base.py
@@ -19,7 +19,12 @@
 from pydantic import BaseModel as PydanticBaseModel
 from pydantic import ConfigDict, computed_field
 from safetensors.torch import save_file
-from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+from torch.distributed.fsdp import (
+    CPUOffloadPolicy,
+    MixedPrecisionPolicy,
+    fully_shard,
+)
 from torch.distributed.tensor import DTensor, Placement, Shard
 from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 from typing_extensions import NotRequired, Self, TypedDict
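The new imports bring in PyTorch's composable FSDP2 API. As a minimal, self-contained sketch of how these primitives fit together (the toy module, mesh shape, and dtypes below are illustrative, not part of this PR):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Assumes the process group has already been initialized (e.g. via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = torch.nn.Linear(1024, 1024, device="cuda")
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Parameters become DTensors sharded across the mesh; forward/backward run in
# bf16 while gradients are reduced in fp32.
fully_shard(model, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True)
```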
@@ -268,7 +273,40 @@ def fully_shard(
         fsdp_config: FSDPConfig,
     ) -> Self:
         """Fully shard the model parameters."""
-        raise NotImplementedError
+        self.fsdp_config = fsdp_config
+        self.fsdp_mesh = self._init_world_mesh()
+
+        if self.fsdp_config.requires_grad:
+            # Keep an fp32 master copy of every trainable parameter so that
+            # optimizer updates are not performed in the reduced compute dtype.
+            for module in self.modules():
+                for p_name, param in module.named_parameters(recurse=False):
+                    if param.requires_grad:
+                        param_fp32 = torch.nn.Parameter(param.to(dtype=torch.float32))
+                        setattr(module, p_name, param_fp32)
+        else:
+            for param in self.parameters():
+                param.requires_grad = False
+
+        # Shard nested BaseModel submodules first so that each becomes its own
+        # FSDP unit before the outer module is wrapped.
+        for module in self.modules():
+            if module is self:
+                continue
+            if isinstance(module, BaseModel):
+                module.fully_shard(fsdp_config)
+
+        mp_policy = MixedPrecisionPolicy(
+            param_dtype=fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
+        )
+        fully_shard(
+            self,
+            mesh=self.fsdp_mesh,
+            mp_policy=mp_policy,
+            reshard_after_forward=fsdp_config.reshard_after_forward,
+            offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+        )
+        return self

     def save_hf(self, hf_dir: Path | str, save_dtype: torch.dtype = torch.bfloat16, safetensors_prefix: str = "model"):
         with profile_time_and_memory(f"[Saving HF to [{safetensors_prefix}]{hf_dir} cost]"):
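Taken together, the new `fully_shard` implementation (i) optionally recasts trainable parameters to fp32 so the sharded master weights stay in full precision, (ii) recursively shards nested `BaseModel` submodules so each becomes its own FSDP unit, and (iii) wraps the outer model with the mixed-precision and offload policies taken from the config. A hedged usage sketch follows; the `FSDPConfig` constructor call and `MyModel` are assumptions for illustration, though the field names are the ones the method actually reads:

```python
import torch

# Hypothetical construction; only the field names are taken from the diff.
fsdp_config = FSDPConfig(
    param_dtype=torch.bfloat16,    # compute dtype for forward/backward
    reduce_dtype=torch.float32,    # dtype used for gradient reduction
    reshard_after_forward=True,    # free gathered params after each forward
    cpu_offload=False,             # True wraps the model with CPUOffloadPolicy()
    requires_grad=True,            # keep params trainable and cast them to fp32
)

model = MyModel(model_cfg)         # any BaseModel subclass (hypothetical)
model = model.fully_shard(fsdp_config)
```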
@@ -1470,3 +1508,11 @@ def _mark_dynamic(self, seq_ctx: SequenceContext, dim=0):
"""
torch._dynamo.mark_dynamic(seq_ctx.cu_seq_lens_q, dim)
torch._dynamo.mark_dynamic(seq_ctx.cu_seq_lens_k, dim)

+    def _init_world_mesh(self):
+        device = DEVICE
+        world_size = dist.get_world_size()
+
+        # TODO: Support hsdp_sharding_size
+        fsdp_mesh = init_device_mesh(device, (world_size,))
+        return fsdp_mesh
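`_init_world_mesh` currently builds a flat 1-D mesh over the whole world, and the TODO defers hybrid sharding. Below is a sketch of what a 2-D HSDP mesh could look like once `hsdp_sharding_size` is supported; only the parameter name comes from the TODO, the rest is an assumption:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def init_hsdp_mesh(device: str, hsdp_sharding_size: int):
    """Hypothetical HSDP variant: shard within groups, replicate across them."""
    world_size = dist.get_world_size()
    assert world_size % hsdp_sharding_size == 0
    replicate_size = world_size // hsdp_sharding_size
    # fully_shard() treats dim 0 of a 2-D mesh as the replicate (DDP-like) axis
    # and dim 1 as the shard (FSDP) axis.
    return init_device_mesh(
        device,
        (replicate_size, hsdp_sharding_size),
        mesh_dim_names=("replicate", "shard"),
    )
```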